<a href="https://colab.research.google.com/github/arjunparmar/VIRTUON/blob/main/Prashant/GMM_Train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
!cp /content/drive/Shareddrives/Virtuon/Pytorch/cp-vton-plus.zip /content/

In [4]:
!unzip -qq cp-vton-plus.zip -d /content/

In [16]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
import torch
import time
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms
import torch.nn.functional as F
from torchvision import models
import os
import os.path as osp
import json

In [6]:
class CPDataset(data.Dataset):
    """Dataset of (person image, cloth image) pairs for CP-VTON+ training.

    The pair list is read from ``<all_root>/<data_path>/<mode>_pairs.txt``. Each
    non-empty line holds a person image name and a cloth image name separated
    by whitespace. Per-sample files are read from
    ``<all_root>/<data_path>/<mode>/`` (sub-folders ``cloth``, ``cloth-mask``,
    ``warp-cloth``, ``warp-mask``, ``image``, ``image-parse-new``,
    ``image-mask`` and ``pose``).

    Args:
        stage: ``"GMM"`` loads the flat cloth and its mask plus ``grid.png``;
            any other value loads the pre-warped cloth and mask instead.
        all_root: Root folder of the unpacked dataset.
        data_path: Data folder name inside ``all_root``.
        mode: Split name; selects both the pair list and the sample folder.
        radius: Half side length (pixels) of the square drawn per keypoint.
        img_height: Height used for the shape and pose maps.
        img_width: Width used for the shape and pose maps.
    """

    def __init__(self, stage, all_root="cp-vton-plus", data_path = "data", mode="train", radius=5, img_height=256, img_width=192):
        super(CPDataset, self).__init__()

        self.root = all_root
        self.data_root = osp.join(all_root, data_path)
        self.datamode = mode
        self.stage = stage
        self.data_list = "".join([mode, "_pairs.txt"])
        self.fine_height = img_height
        self.fine_width = img_width
        self.radius = radius
        self.data_path = osp.join(all_root, data_path, mode)

        # 3-channel images -> tensors in [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # 1-channel images -> tensors in [-1, 1]. The statistics are real
        # 1-tuples: `(0.5)` is just the float 0.5, not a sequence.
        self.transform_1 = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

        # transform_2 / transform_3 are not used inside this class; they are
        # kept so that external code reading these attributes keeps working.
        self.transform_2 = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5), (0.5, 0.5))
        ])

        self.transform_3 = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        im_names = []
        c_names = []

        with open(osp.join(self.data_root, self.data_list), 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                # Anything other than exactly two fields still raises ValueError.
                im_name, c_name = fields
                im_names.append(im_name)
                c_names.append(c_name)

        self.im_names = im_names
        self.c_names = c_names

    def name(self):
        """Return the dataset's display name."""
        return "CPDataset"

    @staticmethod
    def _load_image(path, mode):
        """Open ``path``, convert it to ``mode`` and close the underlying file.

        ``convert`` returns an independent in-memory image, so the result stays
        usable after the file handle is released.
        """
        with Image.open(path) as img:
            return img.convert(mode)

    def __getitem__(self, index):
        """Build the training sample for pair ``index``.

        Returns:
            dict with keys ``c_name``, ``im_name`` (str), ``cloth``,
            ``cloth_mask``, ``image``, ``agnostic``, ``parse_cloth``, ``shape``,
            ``head``, ``pose_image``, ``grid_image``, ``parse_cloth_mask`` and
            ``shape_ori``. ``grid_image`` is the empty string outside the GMM
            stage. ``agnostic`` stacks the blurred shape (1 channel), the head
            crop (3 channels) and one map per pose keypoint.
        """
        c_name = self.c_names[index]
        im_name = self.im_names[index]

        if self.stage == "GMM":
            cloth_dir, mask_dir, file_name = 'cloth', 'cloth-mask', c_name
        else:
            cloth_dir, mask_dir, file_name = 'warp-cloth', 'warp-mask', im_name

        # Force RGB: self.transform normalises exactly three channels.
        c = self._load_image(osp.join(self.data_path, cloth_dir, file_name), 'RGB')
        cm = self._load_image(osp.join(self.data_path, mask_dir, file_name), 'L')

        c = self.transform(c)
        # Binarise the cloth mask to {0, 1} and add a channel dimension.
        cm_array = (np.array(cm) >= 128).astype(np.float32)
        cm = torch.from_numpy(cm_array).unsqueeze(0)

        # person image
        im = self._load_image(osp.join(self.data_path, 'image', im_name), 'RGB')
        im = self.transform(im)

        # LIP labels

        # [(0, 0, 0),    # 0=Background
        #  (128, 0, 0),  # 1=Hat
        #  (255, 0, 0),  # 2=Hair
        #  (0, 85, 0),   # 3=Glove
        #  (170, 0, 51),  # 4=SunGlasses
        #  (255, 85, 0),  # 5=UpperClothes
        #  (0, 0, 85),     # 6=Dress
        #  (0, 119, 221),  # 7=Coat
        #  (85, 85, 0),    # 8=Socks
        #  (0, 85, 85),    # 9=Pants
        #  (85, 51, 0),    # 10=Jumpsuits
        #  (52, 86, 128),  # 11=Scarf
        #  (0, 128, 0),    # 12=Skirt
        #  (0, 0, 255),    # 13=Face
        #  (51, 170, 221),  # 14=LeftArm
        #  (0, 255, 255),   # 15=RightArm
        #  (85, 255, 170),  # 16=LeftLeg
        #  (170, 255, 85),  # 17=RightLeg
        #  (255, 255, 0),   # 18=LeftShoe
        #  (255, 170, 0)    # 19=RightShoe
        #  (170, 170, 50)   # 20=Skin/Neck/Chest (Newly added after running dataset_neck_skin_correction.py)
        #  ]

        # load parsing image
        parse_name = im_name.replace('.jpg', '.png')
        im_parse = self._load_image(osp.join(self.data_path, 'image-parse-new', parse_name), 'L')
        parse_array = np.array(im_parse)

        im_mask = self._load_image(osp.join(self.data_path, 'image-mask', parse_name), 'L')
        mask_array = np.array(im_mask)

        parse_shape = (mask_array > 0).astype(np.float32)

        # Labels that are kept visible in the "head" crop for this stage.
        if self.stage == 'GMM':
            head_labels = (1, 4, 13)
        else:
            head_labels = (1, 2, 4, 9, 12, 13, 16, 17)
        parse_head = sum((parse_array == label).astype(np.float32) for label in head_labels)

        parse_cloth = sum((parse_array == label).astype(np.float32) for label in (5, 6, 7))

        parse_shape_ori = Image.fromarray((parse_shape*255).astype(np.uint8))

        # Down-sample by 16 and up-sample again to blur the body silhouette.
        parse_shape = parse_shape_ori.resize((self.fine_width//16, self.fine_height//16), Image.BILINEAR)
        parse_shape = parse_shape.resize((self.fine_width, self.fine_height), Image.BILINEAR)

        parse_shape_ori = parse_shape_ori.resize((self.fine_width, self.fine_height), Image.BILINEAR)

        shape_ori = self.transform_1(parse_shape_ori)
        shape = self.transform_1(parse_shape)

        phead = torch.from_numpy(parse_head)
        pcm = torch.from_numpy(parse_cloth)

        # Keep the masked region of the person image; everything else becomes 1.
        im_c = im*pcm + (1 - pcm)
        im_h = im*phead + (1-phead)

        # load pose points
        pose_name = im_name.replace('.jpg', '_keypoints.json')
        with open(osp.join(self.data_path, 'pose', pose_name), 'r') as f:
            pose_label = json.load(f)
        # Flat list of numbers reshaped to rows of three; columns 0 and 1 are
        # used as x and y below.
        pose_data = np.array(pose_label['people'][0]['pose_keypoints']).reshape([-1, 3])

        point_num = pose_data.shape[0]
        pose_map = torch.zeros(point_num, self.fine_height, self.fine_width)

        r = self.radius

        im_pose = Image.new('L', (self.fine_width, self.fine_height))
        pose_draw = ImageDraw.Draw(im_pose)

        for i in range(point_num):
            one_map = Image.new('L', (self.fine_width, self.fine_height))
            draw = ImageDraw.Draw(one_map)
            pointx = pose_data[i, 0]
            pointy = pose_data[i, 1]

            # Keypoints at (or near) the origin are skipped, leaving an empty map.
            if pointx > 1 and pointy > 1:
                box = (pointx - r, pointy - r, pointx + r, pointy + r)
                draw.rectangle(box, 'white', 'white')
                pose_draw.rectangle(box, 'white', 'white')

            one_map = self.transform_1(one_map)
            pose_map[i] = one_map[0]

        im_pose = self.transform_1(im_pose)

        agnostic = torch.cat([shape, im_h, pose_map], 0)

        if self.stage == 'GMM':
            # Transform inside the `with` so pixel data is read before closing.
            with Image.open(osp.join(self.root, 'grid.png')) as grid_img:
                im_g = self.transform(grid_img)
        else:
            im_g = ''

        result = {
            'c_name': c_name,
            'im_name': im_name,
            'cloth': c,
            'cloth_mask': cm,
            'image': im,
            'agnostic': agnostic,
            'parse_cloth': im_c,
            'shape': shape,
            'head': im_h,
            'pose_image': im_pose,
            'grid_image': im_g,
            'parse_cloth_mask': pcm.unsqueeze(0),
            'shape_ori': shape_ori,
        }

        return result

    def __len__(self):
        """Number of (person, cloth) pairs listed in the pair file."""
        return len(self.im_names)


class CPDataLoader(object):
    """Endless batch provider wrapping ``torch.utils.data.DataLoader``.

    Args:
        dataset: Dataset to draw samples from.
        shuffle: When true, samples are drawn through a ``RandomSampler``;
            when false, batches follow the dataset order.
        batch: Batch size.
        workers: Number of DataLoader worker processes.
    """

    def __init__(self, dataset, shuffle=True, batch=4, workers=4):
        super(CPDataLoader, self).__init__()

        sampler = torch.utils.data.sampler.RandomSampler(dataset) if shuffle else None

        # `shuffle` must stay False here: with a sampler the two options are
        # mutually exclusive, and without one we want sequential order. The
        # previous `shuffle=(sampler is None)` re-enabled shuffling exactly
        # when the caller had asked for none.
        self.data_loader = torch.utils.data.DataLoader(
            dataset, batch_size=batch, shuffle=False,
            num_workers=workers, pin_memory=True, sampler=sampler
        )
        self.dataset = dataset
        self.data_iter = iter(self.data_loader)

    def next_batch(self):
        """Return the next batch, restarting the loader when it is exhausted.

        Raises:
            StopIteration: if the loader yields no batch even after a restart.
        """
        try:
            batch = next(self.data_iter)
        except StopIteration:
            self.data_iter = iter(self.data_loader)
            batch = next(self.data_iter)

        return batch


In [7]:
class FeatureL2Norm(torch.nn.Module):
    """Scale a feature tensor to (approximately) unit L2 norm along dim 1."""

    def __init__(self):
        super().__init__()

    def forward(self, feature):
        """Divide ``feature`` by ``sqrt(sum(feature**2, dim=1) + 1e-6)``."""
        eps = 1e-6  # keeps the division finite for all-zero feature vectors
        squared_sum = feature.pow(2).sum(dim=1, keepdim=True)
        norm = (squared_sum + eps).pow(0.5)
        return feature / norm.expand_as(feature)

In [8]:
class FeatureExtraction(nn.Module):
    """Convolutional encoder that down-samples its input by ``2 ** (n_layers + 1)``.

    Layout: one stride-2 convolution to ``ngf`` channels, then ``n_layers``
    stride-2 convolutions that double the channel count (capped at 512), then
    two 3x3 stride-1 convolutions at the final channel count. Every
    convolution except the last is followed by ReLU and BatchNorm; the last
    is followed by ReLU only.

    Args:
        input_nc: Number of input channels.
        ngf: Channel count after the first convolution.
        n_layers: Number of additional stride-2 convolutions.
        use_dropout: Accepted for interface compatibility; not used.
    """

    def __init__(self, input_nc, ngf=64, n_layers=3, use_dropout=False):
        super(FeatureExtraction, self).__init__()

        max_nc = 512  # upper bound on the channel count of any layer

        layers = [
            nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(ngf),
        ]

        # Track the running channel count so consecutive layers always match,
        # and so the tail below works for any (ngf, n_layers) instead of
        # assuming the encoder ended at exactly 512 channels.
        nc = ngf
        for _ in range(n_layers):
            next_nc = min(2 * nc, max_nc)
            layers.append(nn.Conv2d(nc, next_nc, kernel_size=4, stride=2, padding=1))
            layers.append(nn.ReLU(True))
            layers.append(nn.BatchNorm2d(next_nc))
            nc = next_nc

        layers.append(nn.Conv2d(nc, nc, kernel_size=3, stride=1, padding=1))
        layers.append(nn.ReLU(True))
        layers.append(nn.BatchNorm2d(nc))
        layers.append(nn.Conv2d(nc, nc, kernel_size=3, stride=1, padding=1))
        layers.append(nn.ReLU(True))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the encoder and return the feature map."""
        return self.model(x)



class FeatureCorrelation(nn.Module):
    """Dense correlation between every location of two feature maps."""

    def __init__(self):
        super().__init__()

    def forward(self, feature_A, feature_B):
        """Correlate ``feature_A`` with ``feature_B``.

        Both inputs are ``(b, c, h, w)``. The result is ``(b, h*w, h, w)``:
        channel ``j`` holds, for each location of ``feature_B``, its dot
        product over ``c`` with the ``j``-th location of ``feature_A``, where
        locations of ``feature_A`` are enumerated column by column.
        """
        batch, channels, height, width = feature_A.size()
        locations = height * width

        # Transposing A first makes its flattened index run down columns.
        flat_A = feature_A.transpose(2, 3).contiguous().view(batch, channels, locations)
        flat_B = feature_B.contiguous().view(batch, channels, locations).transpose(1, 2)

        scores = torch.bmm(flat_B, flat_A)  # (b, h*w of B, h*w of A)
        # Move the A-location axis into the channel position.
        return scores.view(batch, height, width, locations).permute(0, 3, 1, 2)

class FeatureRegression(nn.Module):
    """Regress transformation parameters from a correlation map.

    Four conv/BatchNorm/ReLU stages (the first two with stride 2) reduce the
    input to 64 channels; the result is flattened, passed through a linear
    layer and squashed with ``tanh``. The linear layer expects ``64 * 4 * 3``
    features, i.e. a 4x3 map after the convolutions, which corresponds to a
    16x12 input.

    Args:
        input_nc: Number of input channels.
        output_dim: Number of regressed parameters.
        use_cuda: Move the layers to the GPU. When CUDA is not available the
            module stays on the CPU and a warning is emitted instead of
            failing.
    """

    def __init__(self, input_nc=512, output_dim=6, use_cuda=True):
        super(FeatureRegression, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(input_nc, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.Conv2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True)
        )
        self.linear = nn.Linear(64 * 4 * 3, output_dim)
        self.tanh = nn.Tanh()

        if use_cuda:
            if torch.cuda.is_available():
                self.conv.cuda()
                self.tanh.cuda()
                self.linear.cuda()
            else:
                # Falling back keeps the notebook usable on CPU-only hosts.
                import warnings
                warnings.warn(
                    "FeatureRegression: CUDA requested but not available; "
                    "keeping the module on the CPU.")

    def forward(self, x):
        """Return ``tanh(linear(flatten(conv(x))))`` with shape ``(batch, output_dim)``."""
        x = self.conv(x)
        # reshape (rather than view) also handles non-contiguous conv output
        x = x.reshape(x.shape[0], -1)
        x = self.linear(x)
        x = self.tanh(x)
        return x

In [9]:
class TpsGridGen(nn.Module):
    """Generate a thin-plate-spline (TPS) sampling grid from control-point offsets.

    A regular ``grid_size x grid_size`` lattice of control points spans
    ``[-1, 1]`` in both axes. ``forward`` receives per-sample offsets for those
    points and returns the warped coordinates of every pixel of an
    ``out_h x out_w`` grid whose coordinates also span ``[-1, 1]``.

    Args:
        out_h: Height of the generated grid.
        out_w: Width of the generated grid.
        use_regular_grid: Build the regular control-point lattice. The
            transformation methods read attributes that are only set when this
            is true.
        grid_size: Number of control points per axis.
        reg_factor: Stored on the instance; not used by the code in this class.
        use_cuda: Create the constant tensors on the GPU. When CUDA is not
            available the tensors stay on the CPU and a warning is emitted.

    The constant tensors are plain attributes, not registered buffers, so
    ``Module.to()`` / ``.cuda()`` do not move them. To stay usable regardless,
    they are moved to ``theta``'s device whenever a transformation is applied.
    """

    def __init__(self, out_h=256, out_w=192, use_regular_grid=True, grid_size=3, reg_factor=0, use_cuda=True):
        super(TpsGridGen, self).__init__()
        if use_cuda and not torch.cuda.is_available():
            import warnings
            warnings.warn(
                "TpsGridGen: CUDA requested but not available; "
                "keeping tensors on the CPU.")
            use_cuda = False

        self.out_h, self.out_w = out_h, out_w
        self.reg_factor = reg_factor
        self.use_cuda = use_cuda

        # create grid in numpy
        self.grid = np.zeros([self.out_h, self.out_w, 3], dtype=np.float32)
        # sampling grid with dim-0 coords (Y)
        self.grid_X, self.grid_Y = np.meshgrid(
            np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
        # grid_X,grid_Y: size [1,H,W,1]
        self.grid_X = torch.FloatTensor(self.grid_X).unsqueeze(0).unsqueeze(3)
        self.grid_Y = torch.FloatTensor(self.grid_Y).unsqueeze(0).unsqueeze(3)
        if use_cuda:
            self.grid_X = self.grid_X.cuda()
            self.grid_Y = self.grid_Y.cuda()

        # initialize regular grid for control points P_i
        if use_regular_grid:
            axis_coords = np.linspace(-1, 1, grid_size)
            self.N = grid_size*grid_size
            P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
            P_X = np.reshape(P_X, (-1, 1))  # size (N,1)
            P_Y = np.reshape(P_Y, (-1, 1))  # size (N,1)
            P_X = torch.FloatTensor(P_X)
            P_Y = torch.FloatTensor(P_Y)
            self.P_X_base = P_X.clone()
            self.P_Y_base = P_Y.clone()
            self.Li = self.compute_L_inverse(P_X, P_Y).unsqueeze(0)
            # size [1,1,1,1,N]
            self.P_X = P_X.unsqueeze(2).unsqueeze(
                3).unsqueeze(4).transpose(0, 4)
            self.P_Y = P_Y.unsqueeze(2).unsqueeze(
                3).unsqueeze(4).transpose(0, 4)
            if use_cuda:
                self.P_X = self.P_X.cuda()
                self.P_Y = self.P_Y.cuda()
                self.P_X_base = self.P_X_base.cuda()
                self.P_Y_base = self.P_Y_base.cuda()

    def forward(self, theta):
        """Warp the regular output grid with control-point offsets ``theta``.

        Args:
            theta: ``(B, 2N)`` or ``(B, 2N, 1, 1)`` tensor; the first ``N``
                entries are X offsets, the remaining ones Y offsets.

        Returns:
            Tensor of shape ``(B, out_h, out_w, 2)`` with (x, y) coordinates,
            on ``theta``'s device.
        """
        device = theta.device
        points = torch.cat((self.grid_X.to(device), self.grid_Y.to(device)), 3)
        return self.apply_transformation(theta, points)

    def compute_L_inverse(self, X, Y):
        """Return the inverse of the ``(N+3) x (N+3)`` TPS system matrix.

        Args:
            X: ``(N, 1)`` tensor of control-point x coordinates.
            Y: ``(N, 1)`` tensor of control-point y coordinates.
        """
        N = X.size()[0]  # num of points (along dim 0)
        # construct matrix K
        Xmat = X.expand(N, N)
        Ymat = Y.expand(N, N)
        P_dist_squared = torch.pow(
            Xmat-Xmat.transpose(0, 1), 2)+torch.pow(Ymat-Ymat.transpose(0, 1), 2)
        # make diagonal 1 to avoid NaN in log computation
        P_dist_squared[P_dist_squared == 0] = 1
        K = torch.mul(P_dist_squared, torch.log(P_dist_squared))
        # construct matrix L
        O = torch.FloatTensor(N, 1).fill_(1)
        Z = torch.FloatTensor(3, 3).fill_(0)
        P = torch.cat((O, X, Y), 1)
        L = torch.cat((torch.cat((K, P), 1), torch.cat(
            (P.transpose(0, 1), Z), 1)), 0)
        Li = torch.inverse(L)
        if self.use_cuda:
            Li = Li.cuda()
        return Li

    def apply_transformation(self, theta, points):
        """Map ``points`` through the TPS defined by offsets ``theta``.

        Args:
            theta: ``(B, 2N)`` or ``(B, 2N, 1, 1)`` control-point offsets.
            points: ``(Bp, H, W, 2)`` coordinates, X in ``[..., 0]`` and Y in
                ``[..., 1]``.

        Returns:
            Tensor of shape ``(B, H, W, 2)`` holding the transformed points.
        """
        if theta.dim() == 2:
            theta = theta.unsqueeze(2).unsqueeze(3)
        # points should be in the [B,H,W,2] format,
        # where points[:,:,:,0] are the X coords
        # and points[:,:,:,1] are the Y coords

        # Bring every constant (and the points) onto theta's device; `.to` is
        # a no-op when they already live there.
        device = theta.device
        points = points.to(device)
        Li = self.Li.to(device)
        P_X_base = self.P_X_base.to(device)
        P_Y_base = self.P_Y_base.to(device)

        # input are the corresponding control points P_i
        batch_size = theta.size()[0]
        # split theta into point coordinates
        Q_X = theta[:, :self.N, :, :].squeeze(3)
        Q_Y = theta[:, self.N:, :, :].squeeze(3)
        # theta holds offsets, so add the resting control-point positions
        Q_X = Q_X + P_X_base.expand_as(Q_X)
        Q_Y = Q_Y + P_Y_base.expand_as(Q_Y)

        # get spatial dimensions of points
        points_b = points.size()[0]
        points_h = points.size()[1]
        points_w = points.size()[2]

        # repeat pre-defined control points along spatial dimensions of points to be transformed
        P_X = self.P_X.to(device).expand((1, points_h, points_w, 1, self.N))
        P_Y = self.P_Y.to(device).expand((1, points_h, points_w, 1, self.N))

        # compute weights for non-linear part
        W_X = torch.bmm(Li[:, :self.N, :self.N].expand(
            (batch_size, self.N, self.N)), Q_X)
        W_Y = torch.bmm(Li[:, :self.N, :self.N].expand(
            (batch_size, self.N, self.N)), Q_Y)
        # reshape
        # W_X,W_Y: size [B,H,W,1,N]
        W_X = W_X.unsqueeze(3).unsqueeze(4).transpose(
            1, 4).repeat(1, points_h, points_w, 1, 1)
        W_Y = W_Y.unsqueeze(3).unsqueeze(4).transpose(
            1, 4).repeat(1, points_h, points_w, 1, 1)
        # compute weights for affine part
        A_X = torch.bmm(Li[:, self.N:, :self.N].expand(
            (batch_size, 3, self.N)), Q_X)
        A_Y = torch.bmm(Li[:, self.N:, :self.N].expand(
            (batch_size, 3, self.N)), Q_Y)
        # reshape
        # A_X,A_Y: size [B,H,W,1,3]
        A_X = A_X.unsqueeze(3).unsqueeze(4).transpose(
            1, 4).repeat(1, points_h, points_w, 1, 1)
        A_Y = A_Y.unsqueeze(3).unsqueeze(4).transpose(
            1, 4).repeat(1, points_h, points_w, 1, 1)

        # compute distance P_i - (grid_X,grid_Y)
        # grid is expanded in point dim 4, but not in batch dim 0, as points P_X,P_Y are fixed for all batch
        points_X_for_summation = points[:, :, :, 0].unsqueeze(
            3).unsqueeze(4).expand(points[:, :, :, 0].size()+(1, self.N))
        points_Y_for_summation = points[:, :, :, 1].unsqueeze(
            3).unsqueeze(4).expand(points[:, :, :, 1].size()+(1, self.N))

        if points_b == 1:
            delta_X = points_X_for_summation-P_X
            delta_Y = points_Y_for_summation-P_Y
        else:
            # use expanded P_X,P_Y in batch dimension
            delta_X = points_X_for_summation - \
                P_X.expand_as(points_X_for_summation)
            delta_Y = points_Y_for_summation - \
                P_Y.expand_as(points_Y_for_summation)

        dist_squared = torch.pow(delta_X, 2)+torch.pow(delta_Y, 2)
        # U: size [1,H,W,1,N]
        dist_squared[dist_squared == 0] = 1  # avoid NaN in log computation
        U = torch.mul(dist_squared, torch.log(dist_squared))

        # expand grid in batch dimension if necessary
        points_X_batch = points[:, :, :, 0].unsqueeze(3)
        points_Y_batch = points[:, :, :, 1].unsqueeze(3)
        if points_b == 1:
            points_X_batch = points_X_batch.expand(
                (batch_size,)+points_X_batch.size()[1:])
            points_Y_batch = points_Y_batch.expand(
                (batch_size,)+points_Y_batch.size()[1:])

        # affine part + weighted sum of radial basis terms
        points_X_prime = A_X[:, :, :, :, 0] + \
            torch.mul(A_X[:, :, :, :, 1], points_X_batch) + \
            torch.mul(A_X[:, :, :, :, 2], points_Y_batch) + \
            torch.sum(torch.mul(W_X, U.expand_as(W_X)), 4)

        points_Y_prime = A_Y[:, :, :, :, 0] + \
            torch.mul(A_Y[:, :, :, :, 1], points_X_batch) + \
            torch.mul(A_Y[:, :, :, :, 2], points_Y_batch) + \
            torch.sum(torch.mul(W_Y, U.expand_as(W_Y)), 4)

        return torch.cat((points_X_prime, points_Y_prime), 3)

In [10]:
class DT(nn.Module):
    """Element-wise absolute difference of two tensors."""

    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        """Return ``|x1 - x2|``."""
        return torch.abs(x1 - x2)


class DT2(nn.Module):
    """Element-wise Euclidean distance between points (x1, y1) and (x2, y2)."""

    def __init__(self):
        # Zero-argument super(): the previous `super(DT, self)` named a
        # different class, so instantiating DT2 raised TypeError.
        super().__init__()

    def forward(self, x1, y1, x2, y2):
        """Return ``sqrt((x1 - x2)**2 + (y1 - y2)**2)`` element-wise."""
        dt = torch.sqrt(torch.mul(x1 - x2, x1 - x2) +
                        torch.mul(y1 - y2, y1 - y2))
        return dt


In [None]:
class AffineGridGen(nn.Module):
    """Build an affine sampling grid of a fixed output size via ``F.affine_grid``."""

    def __init__(self, out_h=256, out_w=192, out_ch=3):
        super().__init__()
        self.out_h, self.out_w, self.out_ch = out_h, out_w, out_ch

    def forward(self, theta):
        """Return the sampling grid for the batch of affine matrices ``theta``.

        The target size is ``(batch, out_ch, out_h, out_w)``, with the batch
        size taken from ``theta``.
        """
        theta = theta.contiguous()
        target_size = torch.Size(
            (theta.size(0), self.out_ch, self.out_h, self.out_w))
        return F.affine_grid(theta, target_size)

In [23]:
class GicLoss(nn.Module):
    """Grid regularisation loss penalising uneven spacing between neighbours.

    For every interior grid location it compares the horizontal distance to
    the left and right neighbours (X coordinates) and the vertical distance
    to the upper and lower neighbours (Y coordinates), and sums the absolute
    differences over the whole batch.

    Args:
        fine_height: Grid height used to build the neighbour slices.
        fine_width: Grid width used to build the neighbour slices.
    """

    def __init__(self, fine_height=256, fine_width=192):
        super().__init__()
        self.dT = DT()
        self.fine_height = fine_height
        self.fine_width = fine_width

    def forward(self, grid):
        """Return the summed spacing imbalance of ``grid`` (``(B, H, W, 2)``)."""
        height, width = self.fine_height, self.fine_width
        inner_rows = slice(1, height - 1)
        inner_cols = slice(1, width - 1)

        coord_x = grid[:, :, :, 0]
        coord_y = grid[:, :, :, 1]

        # X coordinates: centre and its horizontal neighbours.
        x_center = coord_x[:, inner_rows, inner_cols]
        x_left = coord_x[:, inner_rows, 0:width - 2]
        x_right = coord_x[:, inner_rows, 2:width]

        # Y coordinates: centre and its vertical neighbours.
        y_center = coord_y[:, inner_rows, inner_cols]
        y_up = coord_y[:, 0:height - 2, inner_cols]
        y_down = coord_y[:, 2:height, inner_cols]

        horizontal = torch.abs(
            self.dT(x_left, x_center) - self.dT(x_right, x_center))
        vertical = torch.abs(
            self.dT(y_up, y_center) - self.dT(y_down, y_center))

        return torch.sum(horizontal + vertical)

In [19]:
class GMM(nn.Module):
    """Geometric Matching Module: predicts a TPS warp from two inputs.

    ``inputA`` (22 channels) and ``inputB`` (1 channel) are encoded
    separately, L2-normalised and correlated; the correlation map is regressed
    to ``2 * grid_size**2`` control-point offsets, which are turned into a
    dense sampling grid.

    The regression head is built with ``input_nc=192`` and a linear layer
    sized for a 4x3 map, which matches 256x192 inputs (the encoders
    down-sample by 16, giving 16 * 12 = 192 correlation channels). Other
    input sizes are not supported by these hard-coded sizes.

    Args:
        grid_size: Number of TPS control points per axis.
        fine_height: Height of the generated sampling grid.
        fine_width: Width of the generated sampling grid.
    """

    def __init__(self, grid_size = 5, fine_height=256, fine_width=192):
        super(GMM, self).__init__()
        # Use the GPU only when there is one, so the model can also be built
        # and run on CPU-only hosts.
        use_cuda = torch.cuda.is_available()
        self.extractionA = FeatureExtraction(22, ngf=64, n_layers=3)
        self.extractionB = FeatureExtraction(1, ngf=64, n_layers=3)
        self.l2norm = FeatureL2Norm()
        self.correlation = FeatureCorrelation()
        self.regression = FeatureRegression(input_nc=192, output_dim=2*grid_size**2, use_cuda=use_cuda)
        self.gridGen = TpsGridGen(fine_height, fine_width, use_cuda=use_cuda, grid_size=grid_size)

    def forward(self, inputA, inputB):
        """Return ``(grid, theta)`` for a batch.

        ``theta`` is the regression output (control-point offsets) and
        ``grid`` the sampling grid generated from it.
        """
        featureA = self.l2norm(self.extractionA(inputA))
        featureB = self.l2norm(self.extractionB(inputB))
        correlation = self.correlation(featureA, featureB)

        # The regression head may sit on a different device than the encoders
        # (its constructor can move it to the GPU on its own), so follow it
        # instead of hard-coding `.cuda()`.
        regression_device = next(self.regression.parameters()).device
        theta = self.regression(correlation.to(regression_device))
        grid = self.gridGen(theta)
        return grid, theta

In [13]:

def train_gmm(train_loader, model, keep_step=100000, decay_step=100000,
              lr=0.0001, display_count=0, checkpoint_dir=None, save_count=5000):
    """Train the geometric matching module.

    The loss is ``L1(warped cloth mask, on-person cloth mask) + 40 * Lgic``,
    where ``Lgic`` is the grid regularisation term averaged over
    batch * height * width of the sampling grid.

    Args:
        train_loader: Object whose ``next_batch()`` returns a dict with at
            least the keys ``agnostic``, ``cloth_mask`` and
            ``parse_cloth_mask``.
        model: Module mapping ``(agnostic, cloth_mask)`` to ``(grid, theta)``.
        keep_step: Number of steps trained at the initial learning rate.
        decay_step: Number of further steps over which the learning rate
            decays linearly towards zero.
        lr: Initial Adam learning rate.
        display_count: Print the losses every this many steps; 0 disables it.
        checkpoint_dir: When set, ``model.state_dict()`` is saved there every
            ``save_count`` steps; ``None`` disables saving.
        save_count: Checkpoint interval in steps.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.train()

    # criterion
    criterionL1 = nn.L1Loss()
    gicloss = GicLoss()
    # optimizer
    optimizer = torch.optim.Adam(
        model.parameters(), lr=lr, betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1.0 - max(0, step - keep_step) / float(decay_step + 1))

    if checkpoint_dir is not None:
        os.makedirs(checkpoint_dir, exist_ok=True)

    for step in range(keep_step + decay_step):
        iter_start_time = time.time()
        inputs = train_loader.next_batch()

        agnostic = inputs['agnostic'].to(device)
        cm = inputs['cloth_mask'].to(device)
        # Target of the warp: the cloth region as worn by the person.
        pcm = inputs['parse_cloth_mask'].to(device)

        grid, theta = model(agnostic, cm)    # can be added c too for new training
        # The TPS grid is built with linspace(-1, 1), i.e. -1 / 1 address the
        # centres of the corner pixels, which is the align_corners=True
        # convention.
        warped_mask = F.grid_sample(cm, grid, padding_mode='zeros', align_corners=True)

        # Compare against the on-person mask; comparing against `cm` itself
        # would reward the identity warp.
        Lwarp = criterionL1(warped_mask, pcm)
        # grid regularization loss, averaged over batch * H * W
        Lgic = gicloss(grid)
        Lgic = Lgic / (grid.shape[0] * grid.shape[1] * grid.shape[2])

        loss = Lwarp + 40 * Lgic    # total GMM loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Without this call the LambdaLR schedule above never takes effect.
        scheduler.step()

        if display_count and (step+1) % display_count == 0:
            t = time.time() - iter_start_time
            print('step: %8d, time: %.3f, loss: %4f, (40*Lgic): %.8f, Lwarp: %.6f' %
                  (step+1, t, loss.item(), (40*Lgic).item(), Lwarp.item()), flush=True)

        if checkpoint_dir is not None and save_count and (step+1) % save_count == 0:
            torch.save(model.state_dict(),
                       os.path.join(checkpoint_dir, 'step_%06d.pth' % (step+1)))

In [None]:
def main():
    """Train the GMM stage on the ``cp-vton-plus`` data with default settings."""
    # Dataset configured for the GMM stage, rooted at ./cp-vton-plus.
    gmm_dataset = CPDataset("GMM", "cp-vton-plus")

    # Endless batch provider using CPDataLoader's defaults.
    batch_source = CPDataLoader(gmm_dataset)

    gmm = GMM()
    train_gmm(batch_source, gmm)


if __name__ == "__main__":
    main()