In [4]:
!pip install torch 

Collecting torch
  Using cached torch-2.4.1-cp38-cp38-win_amd64.whl.metadata (27 kB)
Collecting filelock (from torch)
  Using cached filelock-3.16.1-py3-none-any.whl.metadata (2.9 kB)
Collecting sympy (from torch)
  Using cached sympy-1.13.3-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Using cached networkx-3.1-py3-none-any.whl.metadata (5.3 kB)
Collecting fsspec (from torch)
  Using cached fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy->torch)
  Using cached mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Downloading torch-2.4.1-cp38-cp38-win_amd64.whl (199.4 MB)
   ---------------------------------------- 0.0/199.4 MB ? eta -:--:--
   ---------------------------------------- 0.5/199.4 MB 3.3 MB/s eta 0:01:00
   ---------------------------------------- 0.8/199.4 MB 3.0 MB/s eta 0:01:06
   ---------------------------------------- 1.0/199.4 MB 2.4 MB/s eta 0:01:23
   ---------------------------------------- 1.3/199.4 M

In [19]:
!pip install torchvision



In [None]:
import torch
import torch.nn as nn
from torchvision import models

class CoarseGenerator(nn.Module):
    """ResNet-18 encoder followed by a small transposed-convolution decoder.

    The encoder's classification head is replaced by ``nn.Identity`` so it
    yields a 512-dimensional feature vector per sample. That vector is
    reshaped to a 1x1 feature map, upsampled by three stride-2 transposed
    convolutions (1 -> 2 -> 4 -> 8), projected to 3 channels, resized to the
    spatial size of the input and squashed to ``[-1, 1]`` with ``tanh``.
    """

    def __init__(self):
        super().__init__()
        # `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1
        # is the checkpoint that `pretrained=True` used to select.
        self.encoder = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.encoder.fc = nn.Identity()  # Remove the final layer
        self.upconv1 = nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1)
        self.upconv2 = nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1)
        self.upconv3 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.final_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)

    def forward(self, x):
        """Generate an image with the same height and width as ``x``.

        Args:
            x: Batch of 3-channel images, indexed as (batch, channel, height,
                width).

        Returns:
            Tensor of shape (batch, 3, height, width) with values in [-1, 1].
        """
        target_size = tuple(x.shape[-2:])

        features = self.encoder(x)
        features = features.view(features.size(0), 512, 1, 1)  # Reshape for upsampling

        decoded = self.upconv1(features)
        decoded = self.upconv2(decoded)
        decoded = self.upconv3(decoded)
        decoded = self.final_conv(decoded)

        # The decoder always produces an 8x8 map; without this resize the
        # output cannot be compared element-wise with a full-size target image.
        if tuple(decoded.shape[-2:]) != target_size:
            decoded = nn.functional.interpolate(
                decoded, size=target_size, mode="bilinear", align_corners=False
            )
        return torch.tanh(decoded)  # Return normalized image


In [12]:
class RefinementNetwork(nn.Module):
    """ResNet-18 encoder plus transposed-convolution decoder used for refinement.

    The encoder's classification head is replaced by ``nn.Identity`` so it
    yields a 512-dimensional feature vector per sample. That vector is
    reshaped to a 1x1 feature map, upsampled by three stride-2 transposed
    convolutions (1 -> 2 -> 4 -> 8), projected to 3 channels, resized to the
    spatial size of the input and squashed to ``[-1, 1]`` with ``tanh``.
    """

    def __init__(self):
        super().__init__()
        # `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1
        # is the checkpoint that `pretrained=True` used to select.
        self.encoder = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.encoder.fc = nn.Identity()  # Removing the final layer
        self.upconv1 = nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1)
        self.upconv2 = nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1)
        self.upconv3 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.final_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)

    def forward(self, x):
        """Produce a refined image with the same height and width as ``x``.

        Args:
            x: Batch of 3-channel images, indexed as (batch, channel, height,
                width).

        Returns:
            Tensor of shape (batch, 3, height, width) with values in [-1, 1].
        """
        target_size = tuple(x.shape[-2:])

        features = self.encoder(x)
        features = features.view(features.size(0), 512, 1, 1)

        decoded = self.upconv1(features)
        decoded = self.upconv2(decoded)
        decoded = self.upconv3(decoded)
        decoded = self.final_conv(decoded)

        # The decoder always produces an 8x8 map; without this resize the
        # output cannot be compared element-wise with a full-size target image.
        if tuple(decoded.shape[-2:]) != target_size:
            decoded = nn.functional.interpolate(
                decoded, size=target_size, mode="bilinear", align_corners=False
            )
        return torch.tanh(decoded)


In [None]:
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import cv2

# Dataset class to load user and clothing images
class TryOnDataset(Dataset):
    """Paired dataset of user images and clothing images read from disk.

    Item ``idx`` pairs ``user_images[idx]`` with ``clothing_images[idx]``.
    Images are read with OpenCV and converted from BGR to RGB before the
    optional ``transform`` is applied to each of them.

    Args:
        user_images: Sequence of paths to user images.
        clothing_images: Sequence of paths to clothing images; must contain
            at least as many entries as ``user_images``.
        transform: Optional callable applied to each loaded image.

    Raises:
        ValueError: If ``clothing_images`` is shorter than ``user_images``.
    """

    def __init__(self, user_images, clothing_images, transform=None):
        # Fail fast: the dataset length follows `user_images`, so a shorter
        # clothing list would otherwise surface as an IndexError mid-epoch.
        if len(clothing_images) < len(user_images):
            raise ValueError(
                f"clothing_images has {len(clothing_images)} entries but "
                f"user_images has {len(user_images)}; every user image needs a pair"
            )
        self.user_images = user_images
        self.clothing_images = clothing_images
        self.transform = transform

    def __len__(self):
        return len(self.user_images)

    @staticmethod
    def _read_rgb(path):
        """Read the image at ``path`` and return it in RGB channel order."""
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path!r}")
        # OpenCV decodes to BGR; convert so downstream code sees RGB.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return the ``(user_image, clothing_image)`` pair at ``idx``.

        Raises:
            FileNotFoundError: If either image cannot be read.
        """
        user_image = self._read_rgb(self.user_images[idx])
        clothing_image = self._read_rgb(self.clothing_images[idx])
        if self.transform:
            user_image = self.transform(user_image)
            clothing_image = self.transform(clothing_image)
        return user_image, clothing_image

# Preprocessing applied to every loaded image: convert the array to a PIL
# image, resize to 256x192, convert to a tensor, then normalize each of the
# three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(256, 192)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Build the paired dataset and a shuffled loader that yields batches of 8.
# NOTE(review): `user_images` and `clothing_images` are not defined in the
# visible cells; they must exist before this cell runs - confirm where they
# are created.
dataset = TryOnDataset(
    user_images=user_images,
    clothing_images=clothing_images,
    transform=transform,
)
dataloader = DataLoader(dataset=dataset, batch_size=8, shuffle=True)


In [None]:
import torch.optim as optim

# Run on the GPU when one is available, otherwise fall back to the CPU so the
# cell does not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator = CoarseGenerator().to(device)
refiner = RefinementNetwork().to(device)
criterion = nn.MSELoss()

optimizer_g = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_r = optim.Adam(refiner.parameters(), lr=0.0002)

# NOTE(review): `num_epochs` was not defined in any visible cell, so a fresh
# kernel raised NameError here. 10 is a placeholder - tune as needed.
num_epochs = 10

for epoch in range(num_epochs):
    for i, (user_image, clothing_image) in enumerate(dataloader):
        user_image = user_image.to(device)
        clothing_image = clothing_image.to(device)

        # Coarse image generation
        optimizer_g.zero_grad()
        coarse_output = generator(user_image)
        loss_g = criterion(coarse_output, clothing_image)  # Add a suitable loss
        loss_g.backward()
        optimizer_g.step()

        # Fine image refinement
        optimizer_r.zero_grad()
        # Detach so the refiner's backward pass stops at its own input. The
        # generator's graph was already consumed by loss_g.backward() and its
        # weights were updated in place, so backpropagating through it again
        # would raise a RuntimeError.
        refined_output = refiner(coarse_output.detach())
        loss_r = criterion(refined_output, clothing_image)  # Loss function for refinement
        loss_r.backward()
        optimizer_r.step()

        if i % 10 == 0:
            print(f'Epoch [{epoch}/{num_epochs}], Step [{i}/{len(dataloader)}], '
                  f'Coarse Loss: {loss_g.item():.4f}, Refine Loss: {loss_r.item():.4f}')


In [15]:
def try_on(user_image, clothing_image):
    """Run the coarse generator and then the refiner on a user image.

    Both networks are switched to evaluation mode and run without gradient
    tracking for the duration of the call; their previous training/eval modes
    are restored afterwards, even if inference raises.

    Args:
        user_image: Input accepted by ``preprocess_image``.
        clothing_image: Input accepted by ``preprocess_image``.

    Returns:
        Whatever ``postprocess_image`` returns for the refined output.
    """
    # Preprocess input
    user_image = preprocess_image(user_image)
    # NOTE(review): the preprocessed clothing image is not consumed by either
    # network below; the call is kept because `preprocess_image` is opaque
    # here - confirm whether the clothing image should be fed to the model.
    clothing_image = preprocess_image(clothing_image)

    networks = (generator, refiner)
    previous_modes = [net.training for net in networks]
    for net in networks:
        net.eval()  # use running BatchNorm statistics instead of updating them
    try:
        with torch.no_grad():  # inference only: skip building the autograd graph
            # Generate coarse image
            coarse_output = generator(user_image)

            # Refine the image
            refined_output = refiner(coarse_output)
    finally:
        for net, was_training in zip(networks, previous_modes):
            net.train(was_training)

    return postprocess_image(refined_output)
