In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import os

batch_size = 16
lr = 0.0002
z_dim = 100
img_size = 64
epochs = 1000 

output_dir = "generated_images"
os.makedirs(output_dir, exist_ok=True)

class DualImageDataset(Dataset):
    def __init__(self, path1, path2, transform=None):
        self.images1 = [os.path.join(path1, img) for img in os.listdir(path1)]
        self.images2 = [os.path.join(path2, img) for img in os.listdir(path2)]
        self.transform = transform

    def __len__(self):
        return min(len(self.images1), len(self.images2))

    def __getitem__(self, idx):
        img1 = Image.open(self.images1[idx]).convert('RGB')
        img2 = Image.open(self.images2[idx]).convert('RGB')

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        return img1, img2

transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

path1 = 'D:\Gan\dataset1'
path2 = 'D:\Gan\dataset2'
dataset = DualImageDataset(path1, path2, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

class Generator(nn.Module):
    def __init__(self, z_dim, img_channels=3):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim + 2 * img_size * img_size * img_channels, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, img_channels * img_size * img_size),
            nn.Tanh()
        )

    def forward(self, noise, img1, img2):
        x = torch.cat((noise, img1.view(img1.size(0), -1), img2.view(img2.size(0), -1)), dim=1)
        return self.gen(x).view(-1, 3, img_size, img_size)

class Discriminator(nn.Module):
    def __init__(self, img_channels=3):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(3 * img_size * img_size, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.disc(img.view(img.size(0), -1))

device = 'cuda' if torch.cuda.is_available() else 'cpu'
generator = Generator(z_dim).to(device)
discriminator = Discriminator().to(device)

optim_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optim_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

criterion = nn.BCELoss()

def save_generated_images(epoch, images, output_dir):
    os.makedirs(f"{output_dir}/epoch_{epoch+1}", exist_ok=True)
    for i, img in enumerate(images):
        save_path = f"{output_dir}/epoch_{epoch+1}/generated_{i+1}.png"
        img = ((img + 1) / 2).clamp(0, 1)  
        transforms.ToPILImage()(img).save(save_path)

for epoch in range(epochs):
    for img1, img2 in dataloader:
        img1, img2 = img1.to(device), img2.to(device)
        batch_size = img1.size(0)

        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        noise = torch.randn(batch_size, z_dim).to(device)
        fake_imgs = generator(noise, img1, img2)
        real_preds = discriminator(img1)
        fake_preds = discriminator(fake_imgs.detach())
        loss_D = criterion(real_preds, real_labels) + criterion(fake_preds, fake_labels)

        optim_D.zero_grad()
        loss_D.backward()
        optim_D.step()

        fake_preds = discriminator(fake_imgs)
        loss_G = criterion(fake_preds, real_labels)

        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

    print(f"Epoch [{epoch+1}/{epochs}] Loss D: {loss_D:.4f}, Loss G: {loss_G:.4f}")
    save_generated_images(epoch, fake_imgs, output_dir)


Epoch [1/1000] Loss D: 0.7955, Loss G: 0.7163
Epoch [2/1000] Loss D: 1.0266, Loss G: 0.7187
Epoch [3/1000] Loss D: 1.0458, Loss G: 0.7184
Epoch [4/1000] Loss D: 1.1381, Loss G: 0.7795
Epoch [5/1000] Loss D: 1.4423, Loss G: 0.7428
Epoch [6/1000] Loss D: 1.2095, Loss G: 0.7462
Epoch [7/1000] Loss D: 1.7498, Loss G: 0.9976
Epoch [8/1000] Loss D: 1.2771, Loss G: 0.9539
Epoch [9/1000] Loss D: 1.0966, Loss G: 0.8596
Epoch [10/1000] Loss D: 1.2038, Loss G: 1.3570
Epoch [11/1000] Loss D: 1.2741, Loss G: 0.7362
Epoch [12/1000] Loss D: 1.0424, Loss G: 0.7838
Epoch [13/1000] Loss D: 1.2450, Loss G: 0.8712
Epoch [14/1000] Loss D: 1.1940, Loss G: 0.7734
Epoch [15/1000] Loss D: 1.2359, Loss G: 0.8332
Epoch [16/1000] Loss D: 1.2170, Loss G: 0.7598
Epoch [17/1000] Loss D: 1.0417, Loss G: 0.6385
Epoch [18/1000] Loss D: 1.2298, Loss G: 0.8733
Epoch [19/1000] Loss D: 0.9801, Loss G: 0.8293
Epoch [20/1000] Loss D: 2.6032, Loss G: 0.7991
Epoch [21/1000] Loss D: 0.8119, Loss G: 0.6892
Epoch [22/1000] Loss D

In [2]:
64*64

4096