In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeatureExtractor(nn.Module):
    """Six conv -> BatchNorm -> ReLU stages mapping an RGB image to 512 feature channels.

    Stages 1-4 use 4x4 stride-2 convolutions (each halves H and W); stages 5-6 use
    3x3 stride-1 convolutions that keep the resolution. A [B, 3, 256, 192] input
    therefore yields a [B, 512, 16, 12] feature map.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride) for conv1..conv6; padding is 1 throughout.
        stage_specs = [
            (3, 64, 4, 2),
            (64, 128, 4, 2),
            (128, 256, 4, 2),
            (256, 512, 4, 2),
            (512, 512, 3, 1),
            (512, 512, 3, 1),
        ]
        # Register convN / bnN alternately so parameter names and order match saved checkpoints.
        for stage_no, (c_in, c_out, k, s) in enumerate(stage_specs, start=1):
            setattr(self, f"conv{stage_no}", nn.Conv2d(c_in, c_out, kernel_size=k, stride=s, padding=1))
            setattr(self, f"bn{stage_no}", nn.BatchNorm2d(c_out))

    def forward(self, x):
        """Apply the six stages in order and return the final activation."""
        for stage_no in range(1, 7):
            conv = getattr(self, f"conv{stage_no}")
            norm = getattr(self, f"bn{stage_no}")
            x = F.relu(norm(conv(x)), inplace=True)
        return x


In [None]:
import torch
import torch.nn as nn

class FeatureCorrelation(nn.Module):
    """Dense dot-product correlation between two feature maps of equal shape.

    For inputs A, B of shape [B, C, h, w] the output has shape [B, h*w, h, w] with

        out[b, wA * h + hA, hB, wB] = sum_c A[b, c, hA, wA] * B[b, c, hB, wB]

    i.e. the channel axis enumerates positions of A (w-major order) and the
    spatial axes index positions of B.
    """

    def __init__(self):
        super(FeatureCorrelation, self).__init__()

    def forward(self, feature_A, feature_B):
        """Return the [B, h*w, h, w] correlation volume of ``feature_A`` and ``feature_B``."""
        b, c, h, w = feature_A.size()

        # reshape (not view) so non-contiguous inputs (e.g. transposed or sliced tensors) also work
        feature_A = feature_A.transpose(2, 3).reshape(b, c, h * w)   # [B, C, w*h]  (w-major)
        feature_B = feature_B.reshape(b, c, h * w).transpose(1, 2)   # [B, h*w, C]  (h-major)

        # all pairwise dot products over the channel dimension
        feature_mul = torch.bmm(feature_B, feature_A)                # [B, h*w (B pos), w*h (A pos)]

        # [B, hB, wB, A pos] -> [B, A pos, hB, wB]
        correlation_tensor = feature_mul.view(b, h, w, h * w).transpose(2, 3).transpose(1, 2)

        return correlation_tensor


In [None]:
import torch
import torch.nn as nn

class FeatureRegression(nn.Module):
    """Regress TPS parameters from a correlation volume.

    Two 4x4 stride-2 convs halve H and W twice, two 3x3 convs refine the features,
    and a linear layer followed by tanh maps the flattened result to
    ``output_dim`` values in [-1, 1].

    Args:
        input_nc: input channels. For ``FeatureCorrelation`` output this is h*w of
            the feature maps (16 * 12 = 192 for 256x192 images through FeatureExtractor).
        output_dim: number of regressed values (18 = 2 * 9 control points of a 3x3 grid).
        feature_hw: (height, width) of the conv-stack output, i.e. the input spatial
            size divided by 4. Defaults to (4, 3) for a 16x12 input.
    """

    def __init__(self, input_nc=192, output_dim=18, feature_hw=(4, 3)):
        super(FeatureRegression, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_nc, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        feat_h, feat_w = feature_hw
        self.linear = nn.Linear(feat_h * feat_w * 64, output_dim)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map [B, input_nc, H, W] to [B, output_dim] values in [-1, 1]."""
        x = self.conv(x)
        x = x.flatten(1)  # [B, 64 * feat_h * feat_w]
        x = self.linear(x)
        x = self.tanh(x)
        return x


In [None]:
class TpsGridGen(nn.Module):
    """Thin-plate-spline (TPS) sampling-grid generator.

    A regular ``grid_size x grid_size`` lattice of control points is placed in
    normalised [-1, 1] coordinates. ``forward(theta)`` treats ``theta`` as
    per-sample offsets of those control points and returns the TPS-warped
    ``[B, out_h, out_w, 2]`` grid (x in ``[..., 0]``, y in ``[..., 1]``), suitable
    for ``F.grid_sample``. ``theta == 0`` yields the identity grid.

    Args:
        out_h, out_w: size of the generated sampling grid.
        use_regular_grid: must be True; only a regular control-point lattice is implemented.
        grid_size: control points per axis, so ``theta`` has ``2 * grid_size**2``
            entries (first half: x offsets, second half: y offsets).
        reg_factor: stored but not used by this implementation.
        use_cuda: keep the precomputed tensors on the GPU. Falls back to the CPU
            when CUDA is not available.
    """

    def __init__(self, out_h=256, out_w=192, use_regular_grid=True, grid_size=3, reg_factor=0, use_cuda=True):
        super(TpsGridGen, self).__init__()
        if not use_regular_grid:
            # Everything below (self.N, self.P_X, self.Li, ...) is only built for the regular
            # lattice; fail here instead of with an AttributeError inside forward().
            raise NotImplementedError("TpsGridGen only supports use_regular_grid=True")
        self.out_h, self.out_w = out_h, out_w
        self.reg_factor = reg_factor
        # Degrade to CPU instead of failing inside .cuda() on machines without a GPU.
        self.use_cuda = use_cuda and torch.cuda.is_available()

        # Output sampling grid in normalised coordinates.
        grid_X, grid_Y = np.meshgrid(np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
        # grid_X, grid_Y: size [1, H, W, 1]
        self.grid_X = torch.FloatTensor(grid_X).unsqueeze(0).unsqueeze(3)
        self.grid_Y = torch.FloatTensor(grid_Y).unsqueeze(0).unsqueeze(3)

        # Regular lattice of control points P_i (N = grid_size**2).
        axis_coords = np.linspace(-1, 1, grid_size)
        self.N = grid_size * grid_size
        P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
        P_X = torch.FloatTensor(np.reshape(P_X, (-1, 1)))  # size (N, 1)
        P_Y = torch.FloatTensor(np.reshape(P_Y, (-1, 1)))  # size (N, 1)
        self.P_X_base = P_X.clone()
        self.P_Y_base = P_Y.clone()
        # Inverse TPS system matrix, size [1, N+3, N+3]
        self.Li = self.compute_L_inverse(P_X, P_Y).unsqueeze(0)
        # P_X, P_Y: size [1, 1, 1, 1, N], broadcastable against the grid points
        self.P_X = P_X.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0, 4)
        self.P_Y = P_Y.unsqueeze(2).unsqueeze(3).unsqueeze(4).transpose(0, 4)

        if self.use_cuda:
            self.grid_X = self.grid_X.cuda()
            self.grid_Y = self.grid_Y.cuda()
            self.P_X = self.P_X.cuda()
            self.P_Y = self.P_Y.cuda()
            self.P_X_base = self.P_X_base.cuda()
            self.P_Y_base = self.P_Y_base.cuda()

    def forward(self, theta):
        """Return the [B, out_h, out_w, 2] grid warped by control-point offsets ``theta``."""
        warped_grid = self.apply_transformation(theta, torch.cat((self.grid_X, self.grid_Y), 3))
        return warped_grid

    def compute_L_inverse(self, X, Y):
        """Invert the (N+3) x (N+3) TPS system matrix L = [[K, P], [P^T, 0]].

        Args:
            X, Y: control-point coordinates, each of size (N, 1).

        Returns:
            L^-1 of size (N+3, N+3), moved to the GPU when ``self.use_cuda``.
        """
        N = X.size()[0]  # num of points (along dim 0)
        # construct matrix K with K_ij = r_ij^2 * log(r_ij^2)
        Xmat = X.expand(N, N)
        Ymat = Y.expand(N, N)
        P_dist_squared = torch.pow(Xmat - Xmat.transpose(0, 1), 2) + torch.pow(Ymat - Ymat.transpose(0, 1), 2)
        # zero distances (the diagonal) -> 1, so log gives 0 and K_ii = 0 without NaN
        P_dist_squared[P_dist_squared == 0] = 1
        K = torch.mul(P_dist_squared, torch.log(P_dist_squared))
        # construct matrix L
        O = torch.FloatTensor(N, 1).fill_(1)
        Z = torch.FloatTensor(3, 3).fill_(0)
        P = torch.cat((O, X, Y), 1)
        L = torch.cat((torch.cat((K, P), 1), torch.cat((P.transpose(0, 1), Z), 1)), 0)
        Li = torch.inverse(L)
        if self.use_cuda:
            Li = Li.cuda()
        return Li

    def apply_transformation(self, theta, points):
        """Map ``points`` through the TPS defined by control-point offsets ``theta``.

        Args:
            theta: [B, 2N] (or [B, 2N, 1, 1]) offsets added to the base lattice;
                the first N entries are x offsets, the last N are y offsets.
            points: [1 or B, H, W, 2] points; points[..., 0] are X and
                points[..., 1] are Y coordinates.

        Returns:
            [B, H, W, 2] transformed points, on ``theta``'s original device.
        """
        orig_device = theta.device
        # Compute on the device that holds the precomputed tensors (Li, P_*) instead
        # of a hard-coded "cuda:0", which can differ from the device .cuda() used.
        work_device = self.Li.device
        theta = theta.to(work_device)
        points = points.to(work_device)
        if theta.dim() == 2:
            theta = theta.unsqueeze(2).unsqueeze(3)

        # input are the corresponding control points P_i
        batch_size = theta.size()[0]
        # split theta into point coordinates; theta is an offset from the base lattice
        Q_X = theta[:, :self.N, :, :].squeeze(3)  # [B, N, 1]
        Q_Y = theta[:, self.N:, :, :].squeeze(3)  # [B, N, 1]
        Q_X = Q_X + self.P_X_base.expand_as(Q_X)
        Q_Y = Q_Y + self.P_Y_base.expand_as(Q_Y)

        # get spatial dimensions of points
        points_b = points.size()[0]
        points_h = points.size()[1]
        points_w = points.size()[2]
        # repeat pre-defined control points along spatial dimensions of points to be transformed
        P_X = self.P_X.expand((1, points_h, points_w, 1, self.N))
        P_Y = self.P_Y.expand((1, points_h, points_w, 1, self.N))

        # compute weights for non-linear part
        W_X = torch.bmm(self.Li[:, :self.N, :self.N].expand((batch_size, self.N, self.N)), Q_X)
        W_Y = torch.bmm(self.Li[:, :self.N, :self.N].expand((batch_size, self.N, self.N)), Q_Y)
        # W_X, W_Y: size [B, H, W, 1, N]
        W_X = W_X.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, points_h, points_w, 1, 1)
        W_Y = W_Y.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, points_h, points_w, 1, 1)
        # compute weights for affine part
        A_X = torch.bmm(self.Li[:, self.N:, :self.N].expand((batch_size, 3, self.N)), Q_X)
        A_Y = torch.bmm(self.Li[:, self.N:, :self.N].expand((batch_size, 3, self.N)), Q_Y)
        # A_X, A_Y: size [B, H, W, 1, 3]
        A_X = A_X.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, points_h, points_w, 1, 1)
        A_Y = A_Y.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, points_h, points_w, 1, 1)

        # compute distance P_i - (grid_X, grid_Y)
        # grid is expanded in point dim 4, but not in batch dim 0, as points P_X, P_Y are fixed for all batch
        points_X_for_summation = points[:, :, :, 0].unsqueeze(3).unsqueeze(4).expand(points[:, :, :, 0].size() + (1, self.N))
        points_Y_for_summation = points[:, :, :, 1].unsqueeze(3).unsqueeze(4).expand(points[:, :, :, 1].size() + (1, self.N))

        if points_b == 1:
            delta_X = points_X_for_summation - P_X
            delta_Y = points_Y_for_summation - P_Y
        else:
            # use expanded P_X, P_Y in batch dimension
            delta_X = points_X_for_summation - P_X.expand_as(points_X_for_summation)
            delta_Y = points_Y_for_summation - P_Y.expand_as(points_Y_for_summation)

        dist_squared = torch.pow(delta_X, 2) + torch.pow(delta_Y, 2)
        # U: size [points_b, H, W, 1, N]
        dist_squared[dist_squared == 0] = 1  # avoid NaN in log computation
        U = torch.mul(dist_squared, torch.log(dist_squared))

        # expand grid in batch dimension if necessary
        points_X_batch = points[:, :, :, 0].unsqueeze(3)
        points_Y_batch = points[:, :, :, 1].unsqueeze(3)
        if points_b == 1:
            points_X_batch = points_X_batch.expand((batch_size,) + points_X_batch.size()[1:])
            points_Y_batch = points_Y_batch.expand((batch_size,) + points_Y_batch.size()[1:])

        points_X_prime = (A_X[:, :, :, :, 0]
                          + torch.mul(A_X[:, :, :, :, 1], points_X_batch)
                          + torch.mul(A_X[:, :, :, :, 2], points_Y_batch)
                          + torch.sum(torch.mul(W_X, U.expand_as(W_X)), 4))

        points_Y_prime = (A_Y[:, :, :, :, 0]
                          + torch.mul(A_Y[:, :, :, :, 1], points_X_batch)
                          + torch.mul(A_Y[:, :, :, :, 2], points_Y_batch)
                          + torch.sum(torch.mul(W_Y, U.expand_as(W_Y)), 4))

        points_Y_prime = points_Y_prime.to(orig_device)
        points_X_prime = points_X_prime.to(orig_device)

        return torch.cat((points_X_prime, points_Y_prime), 3)


In [None]:
from torch.utils.data import Dataset
import os
from PIL import Image
from torchvision import transforms

class GMMDataset(Dataset):
    """Paired samples for training the Geometric Matching Module.

    Each item is ``(segm, person, y, cloth)``: segmentation map, person image,
    pre-computed deformed cloth (the target) and in-shop cloth. All four are RGB
    tensors resized to ``image_size`` and normalised to [-1, 1].

    Args:
        names: sample stems; ``<name>.png`` (segm_dir, deformed_dir) and
            ``<name>.jpg`` (person_dir, cloth_dir) must exist.
        segm_dir, person_dir, cloth_dir, deformed_dir: source directories. The
            defaults are the Kaggle dataset paths this notebook was written for.
        image_size: (height, width) that every image is resized to.
    """

    def __init__(self, names,
                 segm_dir='/kaggle/input/high-resolution-viton-zalando-dataset/test/image-parse-v3',
                 person_dir='/kaggle/input/high-resolution-viton-zalando-dataset/test/image',
                 cloth_dir='/kaggle/input/high-resolution-viton-zalando-dataset/test/cloth',
                 deformed_dir='/kaggle/input/deformed-cloth-test/deformed_cloth_test_png',
                 image_size=(256, 192)):
        self.names = names

        self.segm_dir = segm_dir
        self.person_dir = person_dir
        self.cloth_dir = cloth_dir
        self.deformed_dir = deformed_dir

        normalize = transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            normalize,
        ])
        # The parse map encodes discrete labels as colours. Nearest-neighbour resizing
        # keeps those colours exact, whereas the default bilinear resize blends
        # neighbouring labels into colours that belong to no label.
        self.segm_transform = transforms.Compose([
            transforms.Resize(image_size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor(),
            normalize,
        ])

    def __len__(self):
        return len(self.names)

    def _load(self, directory, filename, transform):
        """Open ``directory/filename`` as RGB, apply ``transform``, and close the file."""
        with Image.open(os.path.join(directory, filename)) as img:
            return transform(img.convert('RGB'))

    def __getitem__(self, idx):
        name = self.names[idx]

        segm = self._load(self.segm_dir, name + '.png', self.segm_transform)
        person = self._load(self.person_dir, name + '.jpg', self.transform)
        cloth = self._load(self.cloth_dir, name + '.jpg', self.transform)
        y = self._load(self.deformed_dir, name + '.png', self.transform)

        return segm, person, y, cloth


In [None]:
from sklearn.model_selection import train_test_split
import glob
import os

# Each sample is keyed by the file stem of its person image ("<stem>.jpg" -> "<stem>").
image_dir = '/kaggle/input/high-resolution-viton-zalando-dataset/test/image'
all_image_paths = sorted(glob.glob(os.path.join(image_dir, '*.jpg')))
# splitext strips only the final extension, so stems that contain dots stay intact
# (split('.')[0] would truncate them).
all_names = [os.path.splitext(os.path.basename(p))[0] for p in all_image_paths]
if not all_names:
    # Report the real cause here instead of a confusing error from train_test_split.
    raise FileNotFoundError(f"No .jpg images found in {image_dir!r}")

# Split: 80% train, 20% val (random_state keeps the split identical across runs)
train_names, val_names = train_test_split(all_names, test_size=0.2, random_state=42)

In [None]:
from torch.utils.data import DataLoader

# One dataset per split; both read the same directories and differ only in the name list.
train_dataset, val_dataset = GMMDataset(train_names), GMMDataset(val_names)

# Only the training loader reshuffles every epoch; validation order stays fixed.
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=2,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=2,
)


In [None]:
import matplotlib.pyplot as plt
import random
import torchvision.utils as vutils

def denorm(tensor):
    """Undo ``Normalize(mean=0.5, std=0.5)``: map values from [-1, 1] back to [0, 1]."""
    return tensor.mul(0.5).add(0.5)

def show_batch(dataloader, title="Batch Samples"):
    """Plot one random sample from the first batch of ``dataloader``.

    Args:
        dataloader: yields ``(segm, person, y, cloth)`` batches of images
            normalised to [-1, 1] (as produced by ``GMMDataset``).
        title: figure suptitle.
    """
    segm, person, y, cloth = next(iter(dataloader))  # get first batch

    # Pick a random index from batch
    idx = random.randint(0, segm.size(0) - 1)

    panels = (
        (segm, "Segmentation Map (A)"),
        (person, "Person (B)"),
        (cloth, "Original Cloth (input)"),
        (y, "Deformed Cloth (Ground Truth)"),
    )
    # Wider than 8in so the four fairly long panel titles have room.
    fig, axs = plt.subplots(1, len(panels), figsize=(12, 3))
    for ax, (images, panel_title) in zip(axs, panels):
        # Clamp: float round-off can push denormalised values just outside [0, 1],
        # which imshow clips with a warning.
        ax.imshow(denorm(images[idx]).clamp(0, 1).permute(1, 2, 0).cpu().numpy())
        ax.set_title(panel_title)
        ax.axis("off")

    plt.suptitle(title)
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure; this helper may be called repeatedly

# Preview one random sample from each split.
show_batch(train_loader, title="Random Training Batch Sample")
show_batch(val_loader, title="Random Validation Batch Sample")


In [None]:
# Cache of TpsGridGen instances keyed by (out_h, out_w, grid_size, use_cuda). Building
# one inverts the TPS system matrix, so it should happen once, not on every call.
_TPS_GRID_CACHE = {}


def _get_tps_generator(out_h, out_w, grid_size, use_cuda):
    """Return a cached ``TpsGridGen`` for this output size, lattice size and device."""
    key = (out_h, out_w, grid_size, use_cuda)
    if key not in _TPS_GRID_CACHE:
        _TPS_GRID_CACHE[key] = TpsGridGen(out_h=out_h, out_w=out_w, grid_size=grid_size, use_cuda=use_cuda)
    return _TPS_GRID_CACHE[key]


def GMM(segm, person, y, cloth, feature_extractor, feature_correlation, feature_regressor,
        tps_generator=None):
    """Geometric Matching Module forward pass: warp ``cloth`` towards the person.

    Args:
        segm, person: conditioning images; their features are correlated to predict the warp.
        y: ground-truth deformed cloth. Unused here; kept so existing call sites still work.
        cloth: in-shop cloth batch [B, C, H, W] to be warped.
        feature_extractor, feature_correlation, feature_regressor: GMM sub-networks.
        tps_generator: callable mapping theta [B, 2N] to a sampling grid [B, H, W, 2].
            Defaults to a cached ``TpsGridGen`` sized to ``cloth`` and on its device,
            with a control lattice matching theta's length. This means GMM no longer
            relies on a global ``tps`` object being defined.

    Returns:
        The warped cloth, same shape as ``cloth``.

    Raises:
        ValueError: if no ``tps_generator`` is given and theta's length is not
            2 * k**2 for an integer k.
    """
    # Step 1: Extract features
    featureA = feature_extractor(segm)   # [B, C, h, w]
    featureB = feature_extractor(person) # [B, C, h, w]

    # Step 2: Compute correlation
    correlation = feature_correlation(featureA, featureB)  # [B, h*w, h, w]

    # Step 3: Predict TPS parameters
    theta = feature_regressor(correlation)  # [B, 2N]

    # Step 4: Generate TPS grid
    if tps_generator is None:
        grid_size = int(round((theta.size(1) // 2) ** 0.5))
        if theta.size(1) != 2 * grid_size * grid_size:
            raise ValueError(f"theta has {theta.size(1)} values; expected 2 * k**2 for a k x k control grid")
        tps_generator = _get_tps_generator(cloth.size(2), cloth.size(3), grid_size, cloth.is_cuda)
    warped_grid = tps_generator(theta)  # [B, H, W, 2]

    # Step 5: Warp cloth
    warped_cloth = F.grid_sample(cloth, warped_grid, padding_mode='border', align_corners=True)

    return warped_cloth


In [None]:
import torchvision.models as models
import torch.nn as nn

class Vgg19(nn.Module):
    """Pretrained VGG-19 feature trunk cut into five consecutive stages.

    Stage k (attribute ``slice<k>``) holds layers [0, 2), [2, 7), [7, 12),
    [12, 21) and [21, 30) of ``models.vgg19(pretrained=True).features`` for
    k = 1..5. ``forward`` runs the stages in sequence and returns every stage's
    output, so a perceptual loss can compare images at several depths. With
    ``requires_grad=False`` (the default) all weights are frozen.
    """

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        trunk = models.vgg19(pretrained=True).features
        stage_bounds = ((0, 2), (2, 7), (7, 12), (12, 21), (21, 30))
        for stage_no, (first, stop) in enumerate(stage_bounds, start=1):
            stage = nn.Sequential()
            for layer_idx in range(first, stop):
                # keep torchvision's layer index as the child name
                stage.add_module(str(layer_idx), trunk[layer_idx])
            setattr(self, "slice%d" % stage_no, stage)

        if not requires_grad:
            for weight in self.parameters():
                weight.requires_grad = False

    def forward(self, X):
        """Return the list of the five stage activations for input batch ``X``."""
        activations = []
        h = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            activations.append(h)
        return activations


In [None]:
import torch
import torch.nn as nn

class PerceptualLoss(nn.Module):
    """Weighted sum of L1 distances between multi-layer features of two images.

    Args:
        vgg_model: callable returning a list of feature maps (e.g. ``Vgg19``).
        weights: per-feature-map weights; defaults to 1.0 for five maps. Must have at
            least as many entries as ``vgg_model`` returns feature maps.
        imagenet_normalize: if True, inputs given in [0, 1] are standardised with the
            ImageNet mean/std before the VGG. torchvision's pretrained weights expect
            inputs normalised this way. The default (False) keeps the previous
            behaviour of feeding the inputs through unchanged.
    """

    def __init__(self, vgg_model, weights=None, imagenet_normalize=False):
        super(PerceptualLoss, self).__init__()
        self.vgg = vgg_model
        # Copy so that later mutation of the caller's list cannot change the loss.
        self.weights = list(weights) if weights else [1.0, 1.0, 1.0, 1.0, 1.0]
        self.criterion = nn.L1Loss()
        self.imagenet_normalize = imagenet_normalize
        # Non-persistent buffers follow .to(device) but stay out of the state_dict.
        self.register_buffer('imagenet_mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False)
        self.register_buffer('imagenet_std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False)

    def forward(self, x, y):
        """Return the weighted perceptual loss between prediction ``x`` and target ``y``.

        The notebook passes [B, 3, H, W] batches rescaled to [0, 1]. Target features
        are detached, so ``y`` is treated as a constant.

        Raises:
            ValueError: if there are fewer weights than feature maps.
        """
        if self.imagenet_normalize:
            x = (x - self.imagenet_mean) / self.imagenet_std
            y = (y - self.imagenet_mean) / self.imagenet_std
        x_vgg = self.vgg(x)
        y_vgg = self.vgg(y)
        if len(self.weights) < len(x_vgg):
            raise ValueError(
                f"PerceptualLoss got {len(x_vgg)} feature maps but only {len(self.weights)} weights")

        loss = 0.0
        for weight, x_feat, y_feat in zip(self.weights, x_vgg, y_vgg):
            loss += weight * self.criterion(x_feat, y_feat.detach())
        return loss


In [None]:
import torch
import torch.nn as nn

# Loss objects shared by the training cells: pixel-wise L1 plus a frozen VGG-19 perceptual loss.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

l1_loss_fn = nn.L1Loss()
vgg_model = Vgg19().to(device)
perceptual_loss_fn = PerceptualLoss(vgg_model).to(device)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import os

# -----------------------
# ✅ Setup
# -----------------------
# `device`, `l1_loss_fn` and `perceptual_loss_fn` come from the loss-setup cell above;
# they are reused here rather than loading a second frozen VGG-19 onto the device.
torch.manual_seed(42)  # reproducible weight init and DataLoader shuffling

# ✅ Instantiate and move submodules to device
feature_extractor = FeatureExtractor().to(device)
feature_correlation = FeatureCorrelation().to(device)
feature_regressor = FeatureRegression().to(device)
gmm_modules = (feature_extractor, feature_correlation, feature_regressor)

CHECKPOINT_PATH = 'gmm_model.pth'


# ✅ Combined loss function
def combined_loss(output, target, alpha=0.5):
    """Return ``alpha * L1 + (1 - alpha) * perceptual`` between two image batches."""
    l1 = l1_loss_fn(output, target)
    perceptual = perceptual_loss_fn(output, target)
    return alpha * l1 + (1 - alpha) * perceptual


# -----------------------
# ✅ Training Config
# -----------------------
num_epochs = 50
patience = 5  # epochs without validation improvement before stopping
early_stop_counter = 0
best_val_loss = float('inf')

# ✅ Optimizer (parameters in extractor -> correlation -> regressor order;
# FeatureCorrelation has none)
params = [p for module in gmm_modules for p in module.parameters()]
optimizer = optim.Adam(params, lr=1e-4)

# ✅ LR Scheduler. `verbose=True` is deprecated in recent PyTorch releases, so LR
# reductions are printed explicitly after each scheduler step instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

# -----------------------
# ✅ Training Loop
# -----------------------
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # Train mode
    for module in gmm_modules:
        module.train()

    train_loss = 0.0
    for segm, person, y, cloth in train_loader:
        segm, person, y, cloth = segm.to(device), person.to(device), y.to(device), cloth.to(device)

        optimizer.zero_grad()
        output = GMM(segm, person, y, cloth, feature_extractor, feature_correlation, feature_regressor)

        # Both tensors are in [-1, 1]; rescale to [0, 1] for the L1 / perceptual losses
        output = (output + 1) / 2
        y = (y + 1) / 2

        loss = combined_loss(output, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * segm.size(0)

    train_loss /= len(train_loader.dataset)
    print(f"Train Loss: {train_loss:.4f}")

    # Eval mode
    for module in gmm_modules:
        module.eval()

    val_loss = 0.0
    with torch.no_grad():
        for segm, person, y, cloth in val_loader:
            segm, person, y, cloth = segm.to(device), person.to(device), y.to(device), cloth.to(device)

            output = GMM(segm, person, y, cloth, feature_extractor, feature_correlation, feature_regressor)
            output = (output + 1) / 2
            y = (y + 1) / 2

            loss = combined_loss(output, y)
            val_loss += loss.item() * segm.size(0)

    val_loss /= len(val_loader.dataset)
    print(f"Validation Loss: {val_loss:.4f}")

    # Step scheduler and report any learning-rate reduction it made
    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after < lr_before:
        print(f"Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stop_counter = 0
        print("Validation loss improved. Saving model...")

        torch.save({
            'feature_extractor': feature_extractor.state_dict(),
            'feature_correlation': feature_correlation.state_dict(),
            'feature_regressor': feature_regressor.state_dict(),
            # Extra metadata for resuming; code that only reads the three module
            # keys above is unaffected.
            'epoch': epoch + 1,
            'best_val_loss': best_val_loss,
        }, CHECKPOINT_PATH)
    else:
        early_stop_counter += 1
        print(f" No improvement. Early stopping counter: {early_stop_counter}/{patience}")

        if early_stop_counter >= patience:
            print(" Early stopping triggered.")
            break

In [None]:
import matplotlib.pyplot as plt
import torch

def visualize_predictions(feature_extractor, feature_correlation, feature_regressor, dataloader,
                          num_samples=4, checkpoint_path='gmm_model.pth'):
    """Plot inputs, predicted warp and ground truth for the first batch of ``dataloader``.

    Side effect: the three GMM sub-modules are overwritten in place with the
    weights stored at ``checkpoint_path`` and left in eval mode.

    Args:
        feature_extractor, feature_correlation, feature_regressor: GMM sub-networks.
        dataloader: yields ``(segm, person, y, cloth)`` batches normalised to [-1, 1].
        num_samples: maximum number of samples (one figure each) to plot.
        checkpoint_path: checkpoint written by the training cells.
    """
    # Load from best checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)
    feature_extractor.load_state_dict(checkpoint['feature_extractor'])
    feature_correlation.load_state_dict(checkpoint['feature_correlation'])
    feature_regressor.load_state_dict(checkpoint['feature_regressor'])
    for module in (feature_extractor, feature_correlation, feature_regressor):
        module.eval()

    with torch.no_grad():
        segm, person, y, cloth = (t.to(device) for t in next(iter(dataloader)))
        # Get predicted warped cloth
        preds = GMM(segm, person, y, cloth, feature_extractor, feature_correlation, feature_regressor)

    panels = (
        ("Segmentation", segm),
        ("Person", person),
        ("Original Cloth", cloth),
        ("Warped Cloth (Predicted)", preds),
        ("Ground Truth (Deformed)", y),
    )
    for i in range(min(num_samples, segm.size(0))):
        fig, axs = plt.subplots(1, len(panels), figsize=(15, 4))
        for ax, (panel_title, images) in zip(axs, panels):
            # Denormalise [-1, 1] -> [0, 1]; clamp so round-off does not trigger imshow clipping warnings
            image = (images[i] * 0.5 + 0.5).clamp(0, 1)
            ax.imshow(image.permute(1, 2, 0).cpu().numpy())
            ax.set_title(panel_title)
            ax.axis('off')
        plt.tight_layout()
        plt.show()
        plt.close(fig)  # one figure per sample; free each after showing it

In [None]:
# After training: qualitative check of the best checkpoint on held-out (validation) samples.
print("\n🔍 Visualizing Predictions from Validation Set:")
visualize_predictions(
    feature_extractor,
    feature_correlation,
    feature_regressor,
    dataloader=val_loader,
    num_samples=4,
)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Resume from the best checkpoint and keep training with a loss that weights the
# perceptual term more heavily (alpha=0.2 below).
CHECKPOINT_PATH = 'gmm_model.pth'

# Load model weights
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

feature_extractor.load_state_dict(checkpoint['feature_extractor'])
feature_correlation.load_state_dict(checkpoint['feature_correlation'])
feature_regressor.load_state_dict(checkpoint['feature_regressor'])

gmm_modules = (feature_extractor, feature_correlation, feature_regressor)
for module in gmm_modules:
    module.to(device)


# Loss functions: `l1_loss_fn` / `perceptual_loss_fn` from the earlier loss-setup cell
# are reused instead of loading yet another frozen VGG-19 onto the device.
def combined_loss(output, target, alpha=0.2):
    """Return ``alpha * L1 + (1 - alpha) * perceptual`` between two image batches."""
    l1 = l1_loss_fn(output, target)
    perceptual = perceptual_loss_fn(output, target)
    return alpha * l1 + (1 - alpha) * perceptual


def evaluate_gmm(loader):
    """Return the mean ``combined_loss`` of the current GMM weights over ``loader`` (eval mode, no grad)."""
    for module in gmm_modules:
        module.eval()
    total_loss = 0.0
    with torch.no_grad():
        for segm, person, y, cloth in loader:
            segm, person, y, cloth = segm.to(device), person.to(device), y.to(device), cloth.to(device)
            output = GMM(segm, person, y, cloth, feature_extractor, feature_correlation, feature_regressor)
            # rescale [-1, 1] -> [0, 1] before the losses
            loss = combined_loss((output + 1) / 2, (y + 1) / 2)
            total_loss += loss.item() * segm.size(0)
    return total_loss / len(loader.dataset)


# Training config: use the epoch recorded in the checkpoint when present, else 22.
start_epoch = checkpoint.get('epoch', 22)
num_epochs = 50

# Baseline = the loaded weights scored under the *new* loss. Starting from inf would let
# the first resumed epoch overwrite the best checkpoint even if that epoch were worse.
best_val_loss = evaluate_gmm(val_loader)
print(f"Validation loss of loaded checkpoint: {best_val_loss:.4f}")

# Optimizer (fresh Adam state at the base learning rate)
params = [p for module in gmm_modules for p in module.parameters()]
optimizer = optim.Adam(params, lr=1e-4)

# Scheduler (`verbose` is deprecated in recent PyTorch; LR reductions are printed below)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

# Resume training
for epoch in range(start_epoch, num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    for module in gmm_modules:
        module.train()

    train_loss = 0.0
    for batch_idx, (segm, person, y, cloth) in enumerate(train_loader):
        segm, person, y, cloth = segm.to(device), person.to(device), y.to(device), cloth.to(device)

        optimizer.zero_grad()
        output = GMM(segm, person, y, cloth, feature_extractor, feature_correlation, feature_regressor)

        output = (output + 1) / 2
        y = (y + 1) / 2

        loss = combined_loss(output, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * segm.size(0)

        if batch_idx % 10 == 0:
            print(f"Batch {batch_idx + 1}/{len(train_loader)} - Loss: {loss.item():.4f}")

    train_loss /= len(train_loader.dataset)
    print(f"Train Loss: {train_loss:.4f}")

    val_loss = evaluate_gmm(val_loader)
    print(f"Validation Loss: {val_loss:.4f}")

    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after < lr_before:
        print(f"Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

    # Save model if validation improves on the best seen so far (including the loaded weights)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        print("Validation loss improved. Saving model...")
        torch.save({
            'feature_extractor': feature_extractor.state_dict(),
            'feature_correlation': feature_correlation.state_dict(),
            'feature_regressor': feature_regressor.state_dict(),
            'epoch': epoch + 1,
            'best_val_loss': best_val_loss,
        }, CHECKPOINT_PATH)

In [None]:
import matplotlib.pyplot as plt
import torch

def visualize_predictions(feature_extractor, feature_correlation, feature_regressor, dataloader,
                          num_samples=4, checkpoint_path='gmm_model.pth'):
    """Plot inputs, predicted warp and ground truth for the first batch of ``dataloader``.

    Side effect: the three GMM sub-modules are overwritten in place with the
    weights stored at ``checkpoint_path`` and left in eval mode.

    Args:
        feature_extractor, feature_correlation, feature_regressor: GMM sub-networks.
        dataloader: yields ``(segm, person, y, cloth)`` batches normalised to [-1, 1].
        num_samples: maximum number of samples (one figure each) to plot.
        checkpoint_path: checkpoint written by the training cells.
    """
    # Load from best checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)
    feature_extractor.load_state_dict(checkpoint['feature_extractor'])
    feature_correlation.load_state_dict(checkpoint['feature_correlation'])
    feature_regressor.load_state_dict(checkpoint['feature_regressor'])
    for module in (feature_extractor, feature_correlation, feature_regressor):
        module.eval()

    with torch.no_grad():
        segm, person, y, cloth = (t.to(device) for t in next(iter(dataloader)))
        # Get predicted warped cloth
        preds = GMM(segm, person, y, cloth, feature_extractor, feature_correlation, feature_regressor)

    panels = (
        ("Segmentation", segm),
        ("Person", person),
        ("Original Cloth", cloth),
        ("Warped Cloth (Predicted)", preds),
        ("Ground Truth (Deformed)", y),
    )
    for i in range(min(num_samples, segm.size(0))):
        fig, axs = plt.subplots(1, len(panels), figsize=(15, 4))
        for ax, (panel_title, images) in zip(axs, panels):
            # Denormalise [-1, 1] -> [0, 1]; clamp so round-off does not trigger imshow clipping warnings
            image = (images[i] * 0.5 + 0.5).clamp(0, 1)
            ax.imshow(image.permute(1, 2, 0).cpu().numpy())
            ax.set_title(panel_title)
            ax.axis('off')
        plt.tight_layout()
        plt.show()
        plt.close(fig)  # one figure per sample; free each after showing it

In [None]:
# After resumed training: sanity-check the best checkpoint on samples it was trained on.
print("\n Visualizing Predictions from training Set:")
visualize_predictions(
    feature_extractor,
    feature_correlation,
    feature_regressor,
    dataloader=train_loader,
    num_samples=4,
)