In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from transformers import AutoModel, AutoFeatureExtractor

Loading the Pre-Trained Model and Feature Extractor

In [None]:
# Hugging Face Hub identifier of the DeepVTO model.
model_name = "gouthaml/raos-virtual-try-on-model"

# Both objects are resolved from the same identifier so that the preprocessing
# configuration and the weights come from the same checkpoint.
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

Prepare the Dataset

In [None]:
class FashionDataset(Dataset):
    """Paired image dataset: each item is an ``(image, target)`` tuple.

    Args:
        images: Sequence of file paths to the input images.
        targets: Sequence of file paths to the target images; ``targets[i]``
            is paired with ``images[i]``.
        transform: Optional callable applied to both the input and the target
            image after they are loaded and converted to RGB.

    Raises:
        ValueError: If ``images`` and ``targets`` differ in length.
    """

    def __init__(self, images, targets, transform=None):
        # Fail fast on mismatched pairs instead of hitting an IndexError (or
        # silently ignoring surplus targets) in the middle of an epoch.
        if len(images) != len(targets):
            raise ValueError(
                "images and targets must have the same length, "
                f"got {len(images)} and {len(targets)}"
            )
        self.images = images
        self.targets = targets
        self.transform = transform

    def __len__(self):
        """Return the number of (image, target) pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load pair ``idx`` as RGB images and apply ``transform`` to both."""
        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed as soon as the with-block exits.
        with Image.open(self.images[idx]) as raw_image:
            image = raw_image.convert("RGB")
        with Image.open(self.targets[idx]) as raw_target:
            target = raw_target.convert("RGB")
        # Compare against None explicitly: a callable that defines __len__
        # (e.g. an empty container of transforms) may be falsy.
        if self.transform is not None:
            image = self.transform(image)
            target = self.transform(target)
        return image, target

# Placeholder file lists: replace each "path" entry with real image locations.
# train_images[i] is paired with train_targets[i] by FashionDataset.
train_images, train_targets = ["path"], ["path"]

# Preprocessing applied to every loaded image: resize to 224x224, convert to
# a tensor, then normalise each channel with the given mean / std.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(preprocessing_steps)

train_dataset = FashionDataset(
    images=train_images,
    targets=train_targets,
    transform=transform,
)
# Batches of 8 pairs, reshuffled every epoch.
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)

Changes in Model Architecture

In [None]:
class CustomDeepVTOModel(nn.Module):
    """Pretrained backbone followed by a convolutional decoder that emits an image.

    The backbone's ``last_hidden_state`` (``batch, tokens, hidden``) is folded
    into a square feature map and upsampled by a factor of 16 with four
    stride-2 transposed convolutions, e.g. a 14x14 token grid becomes a
    3-channel 224x224 output.

    The previous layer stack started with ``nn.Conv2d(3, 64, ...)`` and four
    stride-2 down-convolutions. It could not consume the 768-channel 14x14
    map that ``forward`` builds: the channel counts did not match and the
    spatial size collapsed to zero. The decoder below is sized for the
    feature map instead.

    Args:
        pretrained_model: Module whose call returns an object exposing
            ``last_hidden_state``. If it has ``config.hidden_size`` that value
            sets the decoder's input channels; otherwise 768 is assumed.

    Note:
        The final ``Tanh`` bounds outputs to ``[-1, 1]``. Targets passed to the
        loss should be scaled to the same range; targets normalised with a
        mean/std that pushes values outside ``[-1, 1]`` cannot be matched
        exactly.
    """

    def __init__(self, pretrained_model):
        super().__init__()
        self.base_model = pretrained_model

        # Read the feature width from the backbone when available so the
        # decoder is not tied to one hidden size; 768 preserves the value
        # that was previously hard-coded in forward().
        config = getattr(pretrained_model, "config", None)
        hidden_size = getattr(config, "hidden_size", None)
        self.hidden_size = hidden_size if isinstance(hidden_size, int) else 768

        self.unet = nn.Sequential(
            # Project backbone features to the decoder width, keeping the grid size.
            nn.Conv2d(self.hidden_size, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            # Each transposed convolution doubles the spatial resolution.
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run the backbone on ``x`` and decode its token features into an image.

        Returns:
            Tensor of shape ``(batch, 3, 16 * side, 16 * side)`` where ``side``
            is the edge length of the backbone's square token grid.

        Raises:
            ValueError: If the tokens cannot be arranged into a square grid,
                with or without one leading token removed.
        """
        features = self.base_model(x).last_hidden_state
        batch_size, num_tokens, channels = features.shape

        side = int(round(num_tokens ** 0.5))
        if side * side != num_tokens:
            # Not a perfect square: assume a single leading class token
            # (as ViT-style backbones prepend) and drop it.
            features = features[:, 1:, :]
            num_tokens -= 1
            side = int(round(num_tokens ** 0.5))
            if side * side != num_tokens:
                raise ValueError(
                    "cannot arrange backbone tokens into a square grid "
                    f"(got sequence length {features.size(1) + 1})"
                )

        # (batch, tokens, channels) -> (batch, channels, side, side)
        feature_map = features.permute(0, 2, 1).reshape(batch_size, channels, side, side)
        return self.unet(feature_map)

# Wrap the pretrained backbone in the custom architecture defined in this cell.
custom_model = CustomDeepVTOModel(pretrained_model=model)

Instantiating the New Model

In [None]:
# Build a fresh wrapper around the pretrained backbone; rebinding the name
# replaces any custom_model instance created earlier in the notebook.
custom_model = CustomDeepVTOModel(pretrained_model=model)

Training Loop

In [None]:
# Pixel-wise regression objective between generated and target images.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(custom_model.parameters(), lr=1e-4)

# Train on whichever device the model currently lives on, so the loop works
# unchanged on CPU and after a manual custom_model.to("cuda").
device = next(custom_model.parameters()).device

num_epochs = 10
for epoch in range(num_epochs):
    custom_model.train()
    running_loss = 0.0
    for images, targets in train_loader:
        # The dataset transform has already resized, tensorised and normalised
        # each batch. Passing it through feature_extractor again would apply
        # that extractor's own preprocessing on top (typically a second
        # rescale/normalise), so the batch is fed to the model directly.
        inputs = images.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = custom_model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # max(..., 1) keeps the report from raising ZeroDivisionError when the
    # loader yields no batches.
    print(f'Epoch {epoch+1}, Loss: {running_loss/max(len(train_loader), 1)}')



Save the fine-tuned model for future use

In [None]:
# Persist only the parameters (state dict) of the fine-tuned model; loading
# them later requires constructing CustomDeepVTOModel first.
fine_tuned_weights = custom_model.state_dict()
torch.save(fine_tuned_weights, 'fine_tuned_deepvto_model.pth')