In [None]:
pip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import os
import glob
import random

#import torchvision.transforms as transforms

In [2]:
# Select the compute device: use the MPS backend when PyTorch reports it as
# available, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Small encoder-decoder generator.
# NOTE(review): the original header called this a U-Net, but forward() has no
# skip connections -- it is a plain down/up stack. A later cell defines another
# class named `Generator`, which replaces this one once that cell has run.
class Generator(nn.Module):
    """Image-to-image generator: two stride-2 convolutions, then two stride-2 transposed convolutions.

    Takes a 3-channel image batch and returns a 3-channel batch squashed by
    ``Tanh``. For even spatial sizes each down stage halves H and W and each
    up stage doubles them, so sizes divisible by 4 come back unchanged.
    """

    def __init__(self):
        super().__init__()
        # Every (transposed) convolution here shares the same geometry.
        geometry = dict(kernel_size=4, stride=2, padding=1)

        # Encoder: 3 -> 64 -> 128 channels.
        self.down1 = nn.Sequential(
            nn.Conv2d(3, 64, **geometry),
            nn.LeakyReLU(0.2),
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(64, 128, **geometry),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )

        # Decoder: 128 -> 64 -> 3 channels.
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, **geometry),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(64, 3, **geometry),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run ``x`` through the encoder and then the decoder, in order."""
        encoded = self.down2(self.down1(x))
        return self.up2(self.up1(encoded))

class Discriminator(nn.Module):
    """Convolutional discriminator producing a grid of per-patch scores.

    Two stride-2 convolutions reduce the resolution, a final stride-1
    convolution collapses to a single channel, and ``Sigmoid`` maps every
    score into (0, 1). A later cell defines another class named
    `Discriminator`, which replaces this one once that cell has run.
    """

    def __init__(self):
        super().__init__()
        stages = [
            # 3 -> 64 channels, no normalisation on the first stage.
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # 64 -> 128 channels with batch normalisation.
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            # Single-channel score map.
            nn.Conv2d(128, 1, kernel_size=4, stride=1, padding=1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Return the score map for the image batch ``x``."""
        scores = self.model(x)
        return scores

Using device: mps


In [3]:
class Generator(nn.Module):
    """Encoder-decoder generator with instance normalisation.

    Three stride-2 convolutions (3 -> 64 -> 128 -> 256 channels) are followed
    by three stride-2 transposed convolutions (256 -> 128 -> 64 -> 3 channels)
    and a final ``Tanh``. For spatial sizes divisible by 8 the output has the
    same shape as the input.
    """

    def __init__(self):
        super().__init__()
        # Shared geometry of every (transposed) convolution in the network.
        geometry = dict(kernel_size=4, stride=2, padding=1)

        # First encoder stage has no normalisation.
        layers = [
            nn.Conv2d(3, 64, **geometry),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Remaining encoder stages: conv -> instance norm -> leaky ReLU.
        for c_in, c_out in ((64, 128), (128, 256)):
            layers += [
                nn.Conv2d(c_in, c_out, **geometry),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Decoder stages: transposed conv -> instance norm -> ReLU.
        for c_in, c_out in ((256, 128), (128, 64)):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, **geometry),
                nn.InstanceNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Output stage: back to 3 channels, squashed by Tanh.
        layers += [
            nn.ConvTranspose2d(64, 3, **geometry),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Translate the image batch ``x`` and return the generated batch."""
        generated = self.model(x)
        return generated

class Discriminator(nn.Module):
    """Convolutional discriminator with instance normalisation.

    Three stride-2 convolutions (3 -> 64 -> 128 -> 256 channels) are followed
    by a stride-1 convolution down to one channel and a ``Sigmoid``, so the
    result is a grid of scores in (0, 1) rather than a single scalar per image.
    """

    def __init__(self):
        super().__init__()
        # First stage has no normalisation.
        layers = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Two further downsampling stages: conv -> instance norm -> leaky ReLU.
        for c_in, c_out in ((64, 128), (128, 256)):
            layers.extend((
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ))

        # Score head: single channel at stride 1, squashed into (0, 1).
        layers.extend((
            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1),
            nn.Sigmoid(),
        ))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score map for the image batch ``x``."""
        scores = self.model(x)
        return scores

In [5]:
from Dataset_get import CelebAPairedDataset

# Preprocessing pipeline handed to the dataset: resize to 256x256, random
# horizontal flip, conversion to a tensor, then per-channel normalisation
# with mean 0.5 and std 0.5.
# NOTE(review): how the dataset applies this to the two images of a pair is
# not visible here -- if it is called once per image, the random flip may
# differ between the side and front views. Confirm against CelebAPairedDataset.
transform = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [6]:
if __name__ == "__main__":
    # Train two generators (side -> front, front -> side) and two
    # discriminators with an MSE adversarial loss plus an L1 cycle loss.

    # Initialize dataset and dataloader
    dataset = CelebAPairedDataset(root_dir='/Volumes/Vids/CelebA/output', transform=transform)
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2)

    # save_image writes to the path it is given; create the sample folder up
    # front so the first logging step does not fail on a fresh checkout.
    os.makedirs('generated_images', exist_ok=True)

    # Initialize models
    G_A2B = Generator().to(device)  # side -> front
    G_B2A = Generator().to(device)  # front -> side
    D_A = Discriminator().to(device)  # judges side images
    D_B = Discriminator().to(device)  # judges front images

    # Define optimizers (both generators share one optimizer)
    optimizer_G = optim.Adam(list(G_A2B.parameters()) + list(G_B2A.parameters()), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D_A = optim.Adam(D_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D_B = optim.Adam(D_B.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # Loss functions
    criterion_GAN = nn.MSELoss().to(device)
    criterion_cycle = nn.L1Loss().to(device)

    def denormalize(tensor):
        """Undo Normalize(mean=0.5, std=0.5): map values from [-1, 1] back to [0, 1]."""
        return tensor * 0.5 + 0.5

    def save_models():
        """Write the current weights of all four networks to the working directory."""
        torch.save(G_A2B.state_dict(), 'G_A2B.pth')
        torch.save(G_B2A.state_dict(), 'G_B2A.pth')
        torch.save(D_A.state_dict(), 'D_A.pth')
        torch.save(D_B.state_dict(), 'D_B.pth')

    # Training loop
    num_epochs = 10
    try:
        for epoch in range(num_epochs):
            for i, (side_img, front_img) in enumerate(dataloader):
                # Move images to the selected device
                side_img = side_img.to(device)
                front_img = front_img.to(device)

                # ---- Train Generators ----
                optimizer_G.zero_grad()

                # Generate fake front and side images
                fake_front = G_A2B(side_img)
                fake_side = G_B2A(front_img)

                # Adversarial loss: generators want the discriminators to output 1.
                # Each discriminator runs once; the prediction also shapes the target.
                pred_fake_front = D_B(fake_front)
                pred_fake_side = D_A(fake_side)
                loss_GAN_A2B = criterion_GAN(pred_fake_front, torch.ones_like(pred_fake_front))
                loss_GAN_B2A = criterion_GAN(pred_fake_side, torch.ones_like(pred_fake_side))

                # Cycle consistency loss: translating there and back should
                # reproduce the original image.
                reconstructed_side = G_B2A(fake_front)
                reconstructed_front = G_A2B(fake_side)
                loss_cycle_A = criterion_cycle(reconstructed_side, side_img)
                loss_cycle_B = criterion_cycle(reconstructed_front, front_img)

                # Total generator loss
                loss_G = loss_GAN_A2B + loss_GAN_B2A + loss_cycle_A + loss_cycle_B
                loss_G.backward()
                optimizer_G.step()

                # ---- Train Discriminator A (side images) ----
                # zero_grad also clears the gradients that loss_G.backward()
                # left on the discriminator parameters.
                optimizer_D_A.zero_grad()
                pred_real_side = D_A(side_img)
                # detach(): the discriminator update must not reach the generator.
                pred_fake_side = D_A(fake_side.detach())
                loss_D_A = criterion_GAN(pred_real_side, torch.ones_like(pred_real_side)) + \
                           criterion_GAN(pred_fake_side, torch.zeros_like(pred_fake_side))
                loss_D_A.backward()
                optimizer_D_A.step()

                # ---- Train Discriminator B (front images) ----
                optimizer_D_B.zero_grad()
                pred_real_front = D_B(front_img)
                pred_fake_front = D_B(fake_front.detach())
                loss_D_B = criterion_GAN(pred_real_front, torch.ones_like(pred_real_front)) + \
                           criterion_GAN(pred_fake_front, torch.zeros_like(pred_fake_front))
                loss_D_B.backward()
                optimizer_D_B.step()

                # Print losses and save images every 500 iterations
                if i % 500 == 0:
                    print(f'Epoch [{epoch}/{num_epochs}] Batch [{i}/{len(dataloader)}] '
                          f'Loss D_A: {loss_D_A.item()}, Loss D_B: {loss_D_B.item()}, Loss G: {loss_G.item()}')

                    # Save generated images (re-generated with the just-updated generators)
                    with torch.no_grad():
                        fake_front = G_A2B(side_img)
                        fake_side = G_B2A(front_img)

                        # Denormalize images for saving
                        save_image(denormalize(fake_front.cpu()), f'generated_images/epoch_{epoch}_batch_{i}_fake_front.png')
                        save_image(denormalize(side_img.cpu()), f'generated_images/epoch_{epoch}_batch_{i}_real_side.png')
                        save_image(denormalize(fake_side.cpu()), f'generated_images/epoch_{epoch}_batch_{i}_fake_side.png')
                        save_image(denormalize(front_img.cpu()), f'generated_images/epoch_{epoch}_batch_{i}_real_front.png')

            # Checkpoint after every epoch so a killed kernel loses at most one epoch.
            save_models()
    finally:
        # Save models. Runs on normal completion and also when training is
        # interrupted (e.g. KeyboardInterrupt); the exception still propagates.
        save_models()

Found 1088718 valid image pairs
Epoch [0/10] Batch [0/68045] Loss D_A: 0.5177501440048218, Loss D_B: 0.5173390507698059, Loss G: 1.8870849609375
Epoch [0/10] Batch [500/68045] Loss D_A: 0.03433552756905556, Loss D_B: 0.060886166989803314, Loss G: 2.09153413772583
Epoch [0/10] Batch [1000/68045] Loss D_A: 0.0007825492066331208, Loss D_B: 0.017512032762169838, Loss G: 2.325256109237671
Epoch [0/10] Batch [1500/68045] Loss D_A: 0.0003827515465673059, Loss D_B: 0.056066982448101044, Loss G: 2.0809197425842285
Epoch [0/10] Batch [2000/68045] Loss D_A: 0.0012499794829636812, Loss D_B: 0.014400693587958813, Loss G: 2.314405679702759
Epoch [0/10] Batch [2500/68045] Loss D_A: 8.648333459859714e-05, Loss D_B: 0.016238341107964516, Loss G: 2.219290256500244
Epoch [0/10] Batch [3000/68045] Loss D_A: 0.00520399771630764, Loss D_B: 0.00017489951278548688, Loss G: 2.2164101600646973
Epoch [0/10] Batch [3500/68045] Loss D_A: 0.0002016467333305627, Loss D_B: 0.00011920143879251555, Loss G: 2.2491574287

KeyboardInterrupt: 