In [None]:
# Install PyTorch, torchvision and torchaudio from the PyTorch cu118 (CUDA 11.8) wheel index.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# Print the NVIDIA driver/GPU status for this runtime (fails if no NVIDIA GPU is attached).
!nvidia-smi

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.cuda.amp import GradScaler, autocast
from PIL import Image
import os
from tqdm import tqdm

In [None]:
# Colab-only: mount Google Drive at /content/drive so the training data stored
# under /content/drive/MyDrive is reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
class FaceDataset(Dataset):
    """Paired side-view / front-view face images of the same identity.

    Expected directory layout::

        root_dir/<identity>/side_face/*.{png,jpg,jpeg}
        root_dir/<identity>/front_face/*.{png,jpg,jpeg}

    Every side image of an identity is paired with every front image of that
    same identity. Entries of ``root_dir`` that lack either sub-directory are
    skipped.

    Args:
        root_dir: Directory holding one sub-directory per identity.
        img_size: Side length, in pixels, of the square tensors returned.

    Attributes:
        pairs: List of ``(side_path, front_path)`` tuples, in sorted order.
        transform: Augmentation + tensor conversion applied to both images.
    """

    # File extensions (compared case-insensitively) treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, img_size=256):
        self.root_dir = root_dir
        self.img_size = img_size
        self.pairs = self._collect_pairs(root_dir)

        # Augmentations operate on the PIL image before tensor conversion.
        self.transform = transforms.Compose([
            # An explicit (H, W) is required: a bare int only fixes the shorter
            # edge, which leaves non-square inputs with differing shapes that
            # cannot be stacked into a batch.
            transforms.Resize((img_size, img_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(0.2, 0.2, 0.2),
            transforms.ToTensor(),
            # Map [0, 1] to [-1, 1].
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted paths of image files directly inside ``directory``."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    @classmethod
    def _collect_pairs(cls, root_dir):
        """Return every ``(side_path, front_path)`` combination per identity.

        Directory listings are sorted so the pair order (and therefore any
        seeded shuffle of it) is reproducible across runs and filesystems.
        """
        pairs = []
        for identity in sorted(os.listdir(root_dir)):
            id_path = os.path.join(root_dir, identity)
            front_dir = os.path.join(id_path, 'front_face')
            side_dir = os.path.join(id_path, 'side_face')

            # Also false when id_path itself is a plain file.
            if not (os.path.isdir(front_dir) and os.path.isdir(side_dir)):
                continue

            fronts = cls._list_images(front_dir)
            sides = cls._list_images(side_dir)
            pairs.extend((s, f) for s in sides for f in fronts)
        return pairs

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB PIL image and release the file handle."""
        with Image.open(path) as img:
            # convert() decodes the pixels into a new image, so the result
            # stays valid after the file is closed.
            return img.convert('RGB')

    def __len__(self):
        """Number of (side, front) pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``{'side': tensor, 'front': tensor}`` for pair ``idx``.

        Both images receive the *same* random flip and colour jitter so the
        target stays geometrically and photometrically consistent with the
        input.
        """
        side_path, front_path = self.pairs[idx]

        side_img = self._load_rgb(side_path)
        front_img = self._load_rgb(front_path)

        # torchvision draws its random parameters from the global torch RNG;
        # replaying the state makes the second call repeat the first call's
        # draws while still advancing the RNG once overall.
        rng_state = torch.get_rng_state()
        side = self.transform(side_img)
        torch.set_rng_state(rng_state)
        front = self.transform(front_img)

        return {
            'side': side,
            'front': front
        }

In [None]:
class UNetGenerator(nn.Module):
    """Four-level U-Net that maps an image to an image of the same size.

    The encoder halves the spatial size four times; the decoder doubles it
    four times, concatenating the matching encoder feature map (skip
    connection) before each of the last three decoder stages. The output goes
    through ``tanh`` and therefore lies in ``[-1, 1]``.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.

    Input height and width must be multiples of 16 so that every decoder
    stage lines up with its skip connection.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        def conv_block(in_ch, out_ch, down=True):
            """Build one conv -> InstanceNorm -> activation stage.

            ``down=True`` halves the spatial size (strided convolution);
            ``down=False`` doubles it (strided transposed convolution).
            """
            if down:
                conv = nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1, bias=False)
                activation = nn.LeakyReLU(0.2)
            else:
                # A stride-1 Conv2d with kernel 4 / padding 1 would shrink the
                # map by one pixel instead of upsampling it, so the skip
                # concatenations in forward() could never match in size.
                conv = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1, bias=False)
                activation = nn.ReLU()
            return nn.Sequential(conv, nn.InstanceNorm2d(out_ch), activation)

        # Downsample: H -> H/2 -> H/4 -> H/8 -> H/16
        self.down1 = conv_block(in_channels, 64)
        self.down2 = conv_block(64, 128)
        self.down3 = conv_block(128, 256)
        self.down4 = conv_block(256, 512)

        # Upsample: input channels of up2..up4 are doubled by the skip concat.
        self.up1 = conv_block(512, 256, down=False)
        self.up2 = conv_block(512, 128, down=False)
        self.up3 = conv_block(256, 64, down=False)
        self.up4 = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, 4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Generate an image from ``x``.

        Args:
            x: Tensor of shape ``(N, in_channels, H, W)`` with ``H`` and ``W``
                multiples of 16.

        Returns:
            Tensor of shape ``(N, out_channels, H, W)`` with values in
            ``[-1, 1]``.

        Raises:
            ValueError: If ``H`` or ``W`` is not a multiple of 16.
        """
        height, width = x.shape[-2:]
        if height % 16 or width % 16:
            raise ValueError(
                f"UNetGenerator needs height and width divisible by 16, got {height}x{width}"
            )

        # Encoder; d1..d3 are kept for the skip connections.
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)

        # Decoder with skip connections concatenated on the channel axis.
        u1 = self.up1(d4)
        u2 = self.up2(torch.cat([u1, d3], 1))
        u3 = self.up3(torch.cat([u2, d2], 1))
        return self.up4(torch.cat([u3, d1], 1))

class PatchGANDiscriminator(nn.Module):
    """Conditional discriminator that scores an image pair patch by patch.

    The two images are stacked on the channel axis and passed through three
    stride-2 stages (conv -> InstanceNorm -> LeakyReLU) followed by a final
    convolution that emits a single-channel map of raw scores.

    Args:
        in_channels: Total channels of the stacked pair (two RGB images by
            default).
    """

    def __init__(self, in_channels=6):
        super().__init__()

        # Channel widths seen by consecutive downsampling stages.
        widths = (in_channels, 64, 128, 256)
        stages = [
            self._downsample_stage(n_in, n_out)
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ]
        # Stride-1 head: one score per receptive-field patch.
        head = nn.Conv2d(widths[-1], 1, 4, padding=1)

        self.model = nn.Sequential(*stages, head)

    @staticmethod
    def _downsample_stage(n_in, n_out):
        """Return one stride-2 conv -> InstanceNorm -> LeakyReLU(0.2) stage."""
        layers = (
            nn.Conv2d(n_in, n_out, 4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(n_out),
            nn.LeakyReLU(0.2),
        )
        return nn.Sequential(*layers)

    def forward(self, img_A, img_B):
        """Score the pair ``(img_A, img_B)``.

        Both tensors are concatenated along dimension 1 (channels), so they
        must agree on every other dimension.

        Returns:
            The single-channel score map produced by ``self.model``.
        """
        stacked = torch.cat([img_A, img_B], 1)
        return self.model(stacked)

In [None]:
def train(data_root='/content/drive/MyDrive/output', epochs=100, batch_size=16,
          num_workers=4, lr=2e-4, lambda_l1=100, checkpoint_every=10,
          checkpoint_dir='.'):
    """Train the generator/discriminator pair on side -> front face images.

    Uses a least-squares (MSE) adversarial loss plus an L1 reconstruction
    loss weighted by ``lambda_l1``. Mixed precision is used when CUDA is
    available and disabled on the CPU fallback. Calling ``train()`` with no
    arguments reproduces the original hard-coded configuration.

    Args:
        data_root: Root directory passed to ``FaceDataset``.
        epochs: Number of passes over the dataset.
        batch_size: Samples per optimisation step.
        num_workers: DataLoader worker processes.
        lr: Adam learning rate for both networks (betas are ``(0.5, 0.999)``).
        lambda_l1: Weight of the L1 term in the generator loss.
        checkpoint_every: Save a checkpoint every this many epochs.
        checkpoint_dir: Directory the checkpoint files are written to.

    Raises:
        ValueError: If no (side, front) image pairs are found under
            ``data_root``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # AMP (float16 autocast + loss scaling) and pinned memory only make sense
    # on CUDA; gating them avoids warnings on the CPU fallback.
    on_cuda = device.type == 'cuda'

    # Initialize models
    generator = UNetGenerator().to(device)
    discriminator = PatchGANDiscriminator().to(device)

    # Optimizers
    g_optim = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    d_optim = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

    # Loss functions
    criterion_gan = nn.MSELoss()
    criterion_l1 = nn.L1Loss()

    # Dataset and DataLoader
    dataset = FaceDataset(data_root)
    if len(dataset) == 0:
        # Fail with an actionable message instead of the sampler's generic one.
        raise ValueError(
            f"No side/front image pairs found under {data_root!r}; expected "
            "<identity>/side_face and <identity>/front_face sub-directories"
        )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            num_workers=num_workers, pin_memory=on_cuda)

    # One scaler shared by both optimizers: step() is called per optimizer,
    # update() once per iteration after both have stepped.
    scaler = GradScaler(enabled=on_cuda)

    os.makedirs(checkpoint_dir, exist_ok=True)

    # Training loop
    for epoch in range(epochs):
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}")
        for batch in pbar:
            real_side = batch['side'].to(device, non_blocking=True)
            real_front = batch['front'].to(device, non_blocking=True)

            # Train Discriminator
            with autocast(enabled=on_cuda):
                # Generated once per iteration and reused for the generator
                # step below; the generator weights do not change in between.
                fake_front = generator(real_side)

                # Real loss
                pred_real = discriminator(real_side, real_front)
                loss_real = criterion_gan(pred_real, torch.ones_like(pred_real))

                # Fake loss (detached: no gradient flows into the generator)
                pred_fake = discriminator(real_side, fake_front.detach())
                loss_fake = criterion_gan(pred_fake, torch.zeros_like(pred_fake))

                d_loss = (loss_real + loss_fake) * 0.5

            d_optim.zero_grad(set_to_none=True)
            scaler.scale(d_loss).backward()
            scaler.step(d_optim)

            # Train Generator (against the freshly updated discriminator)
            with autocast(enabled=on_cuda):
                # GAN loss
                pred_fake = discriminator(real_side, fake_front)
                loss_gan = criterion_gan(pred_fake, torch.ones_like(pred_fake))

                # L1 loss
                loss_l1 = criterion_l1(fake_front, real_front) * lambda_l1

                g_total_loss = loss_gan + loss_l1

            g_optim.zero_grad(set_to_none=True)
            # This backward also deposits gradients on the discriminator;
            # they are discarded by d_optim.zero_grad() in the next iteration.
            scaler.scale(g_total_loss).backward()
            scaler.step(g_optim)

            scaler.update()

            # Logging
            pbar.set_postfix({'D Loss': d_loss.item(), 'G Loss': g_total_loss.item()})

        # Save checkpoint
        if (epoch + 1) % checkpoint_every == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
            torch.save({
                'generator': generator.state_dict(),
                'discriminator': discriminator.state_dict(),
                # Zero-based index of the epoch just finished (file name is one-based).
                'epoch': epoch
            }, checkpoint_path)

In [None]:
# Start the training run; this call blocks until every epoch has finished.
train()