<a href="https://colab.research.google.com/github/Anujsharmagithubbb/Task4_Prodigy-infotech/blob/main/Task4_Prodigy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

import os
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

# -----------------------------
# Config
# -----------------------------
class Config:
    """Training hyper-parameters and file-system paths, held as class attributes.

    A single shared instance is created below as ``cfg``; every value can also
    be read straight off the class (``Config.batch_size``).
    """

    # --- data -------------------------------------------------------------
    dataset_root = "./data/facades"  # expected to hold 'train' and 'val' sub-folders
    image_size = 256  # side length of the square images fed to the networks
    batch_size = 8
    num_workers = 2  # NOTE(review): presumably DataLoader worker count — confirm at call site

    # --- optimisation -----------------------------------------------------
    epochs = 200
    lr = 2e-4
    beta1 = 0.5  # NOTE(review): presumably Adam's beta1 — confirm against optimizer setup
    lambda_l1 = 100.0  # NOTE(review): name suggests the weight of an L1 loss term — confirm

    # --- outputs ----------------------------------------------------------
    checkpoint_dir = "./checkpoints"
    sample_dir = "./samples"
    save_every = 5  # NOTE(review): units (epochs vs. steps) not visible here — confirm

    # --- runtime ----------------------------------------------------------
    # Chosen once, when the class body executes.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

cfg = Config()

# Create the output folders up front (including missing parents);
# exist_ok=True makes re-running this cell harmless.
Path(cfg.checkpoint_dir).mkdir(parents=True, exist_ok=True)
Path(cfg.sample_dir).mkdir(parents=True, exist_ok=True)

# -----------------------------
# Dataset loader
# -----------------------------
class PairedImageDataset(Dataset):
    """Dataset of side-by-side image pairs, each pair stored in one file.

    Every image directly inside ``<root>/<mode>`` is split down the middle:
    the left half becomes ``'A'`` and the right half ``'B'``. Both halves are
    resized to ``image_size x image_size`` and rescaled from [0, 1] to [-1, 1].

    Args:
        root: Dataset root directory containing one sub-folder per split.
        mode: Name of the split sub-folder to read (e.g. ``"train"``).
        image_size: Side length, in pixels, of the square output tensors.
        extensions: File suffixes (with leading dot, case-insensitive) that
            are treated as images.

    Raises:
        FileNotFoundError: If ``<root>/<mode>`` does not exist (from ``iterdir``).
    """

    def __init__(self, root: str, mode: str = "train", image_size: int = 256,
                 extensions=('.jpg', '.png', '.jpeg')):
        self.dir = Path(root) / mode
        self.files = self._collect_files(self.dir, extensions)
        self.size = image_size

        self.transform = transforms.Compose([
            transforms.Resize((self.size, self.size)),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _collect_files(directory, extensions):
        """Return the image files directly inside *directory*, sorted by path.

        A path qualifies when it is a regular file whose lower-cased suffix is
        in *extensions* (also compared lower-cased). Sub-directories or broken
        symlinks that merely carry an image suffix are skipped; otherwise
        ``Image.open`` would fail on them later, inside a worker.
        """
        allowed = {ext.lower() for ext in extensions}
        return sorted(
            p for p in Path(directory).iterdir()
            if p.is_file() and p.suffix.lower() in allowed
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx: int):
        """Return ``{'A': tensor, 'B': tensor, 'path': str}`` for item *idx*.

        A and B are (3, image_size, image_size) float tensors in [-1, 1];
        A is the left half of the source image, B the right half.
        """
        path = self.files[idx]
        # The context manager releases the file handle deterministically.
        # convert() returns an independent image, so it stays valid after close.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        w, h = img.size
        w2 = w // 2
        A = img.crop((0, 0, w2, h))
        B = img.crop((w2, 0, w, h))

        # ToTensor yields values in [0, 1]; map them to [-1, 1].
        A = (self.transform(A) - 0.5) * 2.0
        B = (self.transform(B) - 0.5) * 2.0

        return {'A': A, 'B': B, 'path': str(path)}

# -----------------------------
# Network blocks
# -----------------------------
def weights_init(m):
    """Initialise one module's parameters; meant for ``model.apply(weights_init)``.

    ``apply`` calls this once for every sub-module. Modules are matched by
    class name:

    * name contains ``'Conv'``: weight ~ N(0, 0.02); bias left unchanged.
    * name contains ``'BatchNorm'``: weight ~ N(1, 0.02), bias set to 0.

    A module that matches by name but has no weight/bias tensor is skipped
    rather than raising ``AttributeError``. Examples are a user-defined
    container such as ``ConvBlock``, or ``BatchNorm2d(affine=False)``, whose
    ``weight``/``bias`` are ``None``.
    """
    classname = type(m).__name__
    # nn.init.* already run under torch.no_grad(), so .data is unnecessary.
    weight = getattr(m, 'weight', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            nn.init.constant_(bias, 0)


class UNetDown(nn.Module):
    """U-Net encoder step: a stride-2, 4x4 convolution that halves H and W (floor).

    Layers inside ``self.model``, in order:
    Conv2d(bias=False) -> [BatchNorm2d] -> LeakyReLU(0.2) -> [Dropout].
    Bracketed layers are optional.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        normalize: Insert a BatchNorm2d after the convolution when true.
        dropout: Dropout probability; a falsy value (0.0) omits the layer.
    """

    def __init__(self, in_channels, out_channels, normalize=True, dropout=0.0):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2,
                         padding=1, bias=False)
        norm = [nn.BatchNorm2d(out_channels)] if normalize else []
        act = nn.LeakyReLU(0.2, inplace=True)
        drop = [nn.Dropout(dropout)] if dropout else []
        # Keep this exact order: Sequential indices determine state_dict keys.
        self.model = nn.Sequential(conv, *norm, act, *drop)

    def forward(self, x):
        """Run *x* through the encoder step."""
        features = self.model(x)
        return features


class UNetUp(nn.Module):
    """U-Net decoder step: 2x transposed-conv upsampling, then skip concat.

    Layers inside ``self.model``, in order:
    ConvTranspose2d(4x4, stride 2, padding 1, bias=False) -> BatchNorm2d
    -> ReLU -> [Dropout].

    ``forward`` concatenates the upsampled features (first) with
    ``skip_input`` (second) along dim 1, which is the channel axis for
    batched NCHW input. The result therefore has
    ``out_channels + skip_input.size(1)`` channels.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the transposed convolution.
        dropout: Dropout probability; a falsy value (0.0) omits the layer.
    """

    def __init__(self, in_channels, out_channels, dropout=0.0):
        super().__init__()
        upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4,
                                      stride=2, padding=1, bias=False)
        stages = [upsample, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)]
        if dropout:
            stages += [nn.Dropout(dropout)]
        self.model = nn.Sequential(*stages)

    def forward(self, x, skip_input):
        """Upsample *x* and append *skip_input* after it along dim 1."""
        upsampled = self.model(x)
        return torch.cat([upsampled, skip_input], dim=1)
