In [1]:
!unzip /content/Myntra.zip

Archive:  /content/Myntra.zip
   creating: Myntra/contemporary_images/
  inflating: Myntra/contemporary_images/c1.png  
  inflating: Myntra/contemporary_images/c10.png  
  inflating: Myntra/contemporary_images/c11.png  
  inflating: Myntra/contemporary_images/c12.png  
  inflating: Myntra/contemporary_images/c13.png  
  inflating: Myntra/contemporary_images/c14.png  
  inflating: Myntra/contemporary_images/c15.png  
  inflating: Myntra/contemporary_images/c16.png  
  inflating: Myntra/contemporary_images/c17.png  
  inflating: Myntra/contemporary_images/c18.png  
  inflating: Myntra/contemporary_images/c19.png  
  inflating: Myntra/contemporary_images/c2.png  
  inflating: Myntra/contemporary_images/c20.png  
  inflating: Myntra/contemporary_images/c21.png  
  inflating: Myntra/contemporary_images/c22.png  
  inflating: Myntra/contemporary_images/c23.png  
  inflating: Myntra/contemporary_images/c24.png  
  inflating: Myntra/contemporary_images/c25.png  
  inflating: Myntra/contemporar

In [7]:
from google.colab import drive
drive.mount('/content/drive')

MessageError: Error: credential propagation was unsuccessful

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os


class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.in1 = nn.InstanceNorm2d(out_channels, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.in2 = nn.InstanceNorm2d(out_channels, affine=True)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.in1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.in2(out)
        return out + residual


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.initial = nn.Conv2d(3, 64, kernel_size=7, padding=3)
        self.initial_in = nn.InstanceNorm2d(64, affine=True)
        self.initial_relu = nn.ReLU(inplace=True)
        self.res_blocks = nn.Sequential(*[ResNetBlock(64, 64) for _ in range(9)])
        self.final = nn.Conv2d(64, 3, kernel_size=7, padding=3)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.initial_relu(self.initial_in(self.initial(x)))
        x = self.res_blocks(x)
        x = self.tanh(self.final(x))
        return x


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(512, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, x):
        return self.model(x)


class SingleFolderDataset(Dataset):
    def __init__(self, folder, transform=None):
        self.folder = folder
        self.transform = transform
        self.images = [os.path.join(folder, img) for img in os.listdir(folder) if img.endswith(('png', 'jpg', 'jpeg'))]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image


def load_data(folder, batch_size=1, num_workers=1):
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    dataset = SingleFolderDataset(folder, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)


def identity_loss(real_image, same_image):
    return nn.L1Loss()(real_image, same_image) * 5.0

def train_cycle_gan(vintage_loader, contemporary_loader, num_epochs=100):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Initializing models
    Gv2c = Generator().to(device)  # Vintage to Contemporary
    Gc2v = Generator().to(device)  # Contemporary to Vintage
    Dv = Discriminator().to(device)  # Vintage Discriminator
    Dc = Discriminator().to(device)  # Contemporary Discriminator

    # Optimizers
    optimizer_G = optim.Adam(list(Gv2c.parameters()) + list(Gc2v.parameters()), lr=0.0002, betas=(0.5, 0.999))
    optimizer_Dv = optim.Adam(Dv.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_Dc = optim.Adam(Dc.parameters(), lr=0.0002, betas=(0.5, 0.999))

    criterion_GAN = nn.MSELoss()
    criterion_cycle = nn.L1Loss()

    os.makedirs('output3', exist_ok=True)

    for epoch in range(num_epochs):
        for i, (vintage_images, contemporary_images) in enumerate(zip(vintage_loader, contemporary_loader)):
            vintage_images = vintage_images.to(device)
            contemporary_images = contemporary_images.to(device)

            # Training Generators
            optimizer_G.zero_grad()

            # Forward cycle
            fake_contemporary = Gv2c(vintage_images)
            reconstructed_vintage = Gc2v(fake_contemporary)
            cycle_loss_vintage = criterion_cycle(reconstructed_vintage, vintage_images) * 10.0

            # Backward cycle
            fake_vintage = Gc2v(contemporary_images)
            reconstructed_contemporary = Gv2c(fake_vintage)
            cycle_loss_contemporary = criterion_cycle(reconstructed_contemporary, contemporary_images) * 10.0


            identity_loss_vintage = identity_loss(Gv2c(vintage_images), vintage_images)
            identity_loss_contemporary = identity_loss(Gc2v(contemporary_images), contemporary_images)

            # GAN loss
            loss_G = criterion_GAN(Dc(fake_contemporary), torch.ones_like(Dc(fake_contemporary))) + \
                      criterion_GAN(Dv(fake_vintage), torch.ones_like(Dv(fake_vintage))) + \
                      cycle_loss_vintage + cycle_loss_contemporary + \
                      identity_loss_vintage + identity_loss_contemporary

            loss_G.backward()
            optimizer_G.step()

            # Training Discriminators
            optimizer_Dv.zero_grad()
            loss_Dv_real = criterion_GAN(Dv(vintage_images), torch.ones_like(Dv(vintage_images)))
            loss_Dv_fake = criterion_GAN(Dv(fake_vintage.detach()), torch.zeros_like(Dv(fake_vintage.detach())))
            loss_Dv = (loss_Dv_real + loss_Dv_fake) * 0.5
            loss_Dv.backward()
            optimizer_Dv.step()

            optimizer_Dc.zero_grad()
            loss_Dc_real = criterion_GAN(Dc(contemporary_images), torch.ones_like(Dc(contemporary_images)))
            loss_Dc_fake = criterion_GAN(Dc(fake_contemporary.detach()), torch.zeros_like(Dc(fake_contemporary.detach())))
            loss_Dc = (loss_Dc_real + loss_Dc_fake) * 0.5
            loss_Dc.backward()
            optimizer_Dc.step()

            if epoch % 10 == 0:
                vutils.save_image(fake_contemporary, f"output3/fake_contemporary_epoch{epoch}_img{i}.png", normalize=True)
                vutils.save_image(fake_vintage, f"output3/fake_vintage_epoch{epoch}_img{i}.png", normalize=True)
                print(f'Epoch {epoch}, Loss G: {loss_G.item()}, Loss Dv: {loss_Dv.item()}, Loss Dc: {loss_Dc.item()}')

    torch.save(Gv2c.state_dict(), 'Gv2c.pth')
    torch.save(Gc2v.state_dict(), 'Gc2v.pth')

# Loading data
vintage_loader = load_data('/content/Myntra/vintage_images')
contemporary_loader = load_data('/content/Myntra/contemporary_images')

train_cycle_gan(vintage_loader, contemporary_loader, num_epochs=100)


Using device: cuda
Epoch 0, Loss G: 26.211318969726562, Loss Dv: 0.6351184844970703, Loss Dc: 0.7062029242515564
Epoch 0, Loss G: 21.143856048583984, Loss Dv: 0.9943298101425171, Loss Dc: 0.5488073229789734
Epoch 0, Loss G: 21.001811981201172, Loss Dv: 2.5828819274902344, Loss Dc: 0.6301491260528564
Epoch 0, Loss G: 20.297481536865234, Loss Dv: 0.5936489105224609, Loss Dc: 1.655055046081543
Epoch 0, Loss G: 19.918197631835938, Loss Dv: 0.3101159930229187, Loss Dc: 0.7526329755783081
Epoch 0, Loss G: 17.662368774414062, Loss Dv: 0.28306856751441956, Loss Dc: 0.5356406569480896
Epoch 0, Loss G: 18.75716209411621, Loss Dv: 0.2500198185443878, Loss Dc: 0.42766404151916504
Epoch 0, Loss G: 15.579483985900879, Loss Dv: 0.35815221071243286, Loss Dc: 0.9334192276000977
Epoch 0, Loss G: 14.445233345031738, Loss Dv: 0.17108666896820068, Loss Dc: 0.5087600350379944
Epoch 0, Loss G: 16.093778610229492, Loss Dv: 0.2526012659072876, Loss Dc: 0.5887030363082886
Epoch 0, Loss G: 16.452838897705078, Lo

In [5]:
def generate_images(Gv2c, Gc2v, vintage_image_path, contemporary_image_path):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    Gv2c.to(device)
    Gc2v.to(device)
    Gv2c.eval()
    Gc2v.eval()

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    vintage_image = transform(Image.open(vintage_image_path).convert('RGB')).unsqueeze(0).to(device)
    contemporary_image = transform(Image.open(contemporary_image_path).convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        fake_contemporary = Gv2c(vintage_image)
        fake_vintage = Gc2v(contemporary_image)
        mixed_image = (vintage_image + fake_contemporary + contemporary_image + fake_vintage) / 4.0

    save_image(fake_contemporary, 'output3/fake_contemporary2.png', normalize=True)
    save_image(fake_vintage, 'output3/fake_vintage2.png', normalize=True)
    save_image(mixed_image, 'output3/mixed_image2.png', normalize=True)


Gv2c = Generator()
Gc2v = Generator()
Gv2c.load_state_dict(torch.load('Gv2c.pth'))
Gc2v.load_state_dict(torch.load('Gc2v.pth'))

generate_images(Gv2c, Gc2v, '/content/Myntra/vintage_images/v2.png', '/content/Myntra/contemporary_images/c1.png')
