In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# Hyperparameters
num_samples = 1000      # number of synthetic (damaged, clean) image pairs
image_size = 64         # height and width of each square image
num_epochs = 50
batch_size = 64
learning_rate = 0.0002  # Adam step size shared by both networks
seed = 42               # fixes synthetic data, weight init and shuffle order

# Seed torch's RNG (all devices) before any random tensors or weights are
# created, so a fresh "Restart & Run All" reproduces the same results.
torch.manual_seed(seed);

In [4]:
# Generate synthetic dataset: uniform-random "clean" images in [0, 1) and
# "damaged" copies with additive Gaussian noise (std 0.5). The damaged images
# are not clamped, so their values can fall outside [0, 1].
image_shape = (num_samples, 1, image_size, image_size)
clean_images = torch.rand(image_shape)
noise = 0.5 * torch.randn_like(clean_images)
damaged_images = clean_images + noise


In [5]:
# Create DataLoader: each item is a (damaged, clean) pair, reshuffled every
# epoch. drop_last is left at its default (False), so the final batch holds the
# remainder of num_samples / batch_size.
dataset = TensorDataset(damaged_images, clean_images)
data_loader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=batch_size,
)

In [6]:
# Generator model
class Generator(nn.Module):
    """Encoder-decoder mapping a damaged 1-channel image to a restored image.

    The encoder halves the spatial size three times (64 -> 32 -> 16 -> 8).
    The decoder mirrors it with transposed convolutions (8 -> 16 -> 32 -> 64),
    so the output has the same shape as the input. Height and width must be
    divisible by 8 for the round trip to be exact. A final Sigmoid keeps pixel
    values in [0, 1], the same range as the clean images.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            # Encoder
            nn.Conv2d(1, 32, 4, 2, 1),              # 1x64x64  -> 32x32x32
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 4, 2, 1),             # 32x32x32 -> 64x16x16
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),            # 64x16x16 -> 128x8x8
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # Decoder: upsample back to the input resolution so the output is
            # a full-size image rather than a coarse 4x4 map.
            nn.ConvTranspose2d(128, 64, 4, 2, 1),   # 128x8x8  -> 64x16x16
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),    # 64x16x16 -> 32x32x32
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 1, 4, 2, 1),     # 32x32x32 -> 1x64x64
            nn.Sigmoid()
        )

    def forward(self, x):
        """Map a (N, 1, H, W) batch to a restored (N, 1, H, W) batch in [0, 1]."""
        return self.model(x)



In [7]:
# Discriminator model
class Discriminator(nn.Module):
    """Conditional discriminator over channel-stacked image pairs.

    Input is (N, 2, H, W): the damaged image concatenated along the channel
    axis with a candidate (the clean target or a generated image). For a
    64x64 pair, four stride-2 convolutions produce an (N, 1, 4, 4) map. The
    final Sigmoid turns each entry into a score in [0, 1] that the pair is real.
    """

    def __init__(self):
        super().__init__()
        # Stem: the first layer has no BatchNorm.
        layers = [
            nn.Conv2d(2, 32, kernel_size=4, stride=2, padding=1),  # 2x64x64 -> 32x32x32
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Two downsampling stages with BatchNorm: 32x32x32 -> 64x16x16 -> 128x8x8.
        for in_ch, out_ch in ((32, 64), (64, 128)):
            layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Head: 128x8x8 -> 1x4x4, squashed into [0, 1].
        layers.extend([
            nn.Conv2d(128, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Score a batch of stacked (damaged, candidate) pairs."""
        return self.model(x)

In [8]:
# Initialize generator and discriminator on the GPU when one is available,
# otherwise on the CPU. The generator is constructed first, then the
# discriminator (this fixes the order in which weight init consumes the RNG).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
generator, discriminator = Generator().to(device), Discriminator().to(device)



In [10]:
# Loss function: binary cross-entropy on the discriminator's sigmoid outputs,
# averaged over all elements (the default "mean" reduction, spelled out).
criterion = nn.BCELoss(reduction="mean")


In [11]:
# Optimizers: one Adam instance per network with identical settings;
# beta1 is lowered to 0.5 from Adam's default of 0.9.
adam_betas = (0.5, 0.999)
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=adam_betas)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=adam_betas)

In [19]:
# Training loop: one discriminator update and one generator update per batch.
# The discriminator is conditional: it always sees the damaged image stacked
# channel-wise with either the clean target or the generator's output.
#
# NOTE: the loop variables keep the names `damaged_images`, `clean_images`,
# `valid` and `fake` because later cells read the last batch left in them.
# This shadows the full dataset tensors of the same name.
log_every = 10  # print epoch-average losses every `log_every` epochs

generator.train()
discriminator.train()

for epoch in range(num_epochs):
    d_loss_sum = g_loss_sum = 0.0
    n_batches = 0
    for i, (damaged_images, clean_images) in enumerate(data_loader):
        # Configure input
        damaged_images = damaged_images.to(device)
        clean_images = clean_images.to(device)

        # Generate restorations. Resize only if the generator's output does
        # not already match the input's spatial size.
        generated_batch = generator(damaged_images)
        if generated_batch.shape[-2:] != damaged_images.shape[-2:]:
            generated_batch = nn.functional.interpolate(
                generated_batch,
                size=tuple(damaged_images.shape[-2:]),
                mode='bilinear',
                align_corners=False,
            )

        # Train Discriminator: real pairs -> 1, generated pairs -> 0.
        # detach() stops this step from backpropagating into the generator.
        optimizer_D.zero_grad()
        pred_real = discriminator(torch.cat((damaged_images, clean_images), dim=1))
        pred_fake = discriminator(torch.cat((damaged_images, generated_batch.detach()), dim=1))
        # Adversarial ground truths, shaped like the discriminator output
        valid = torch.ones_like(pred_real)
        fake = torch.zeros_like(pred_fake)
        d_loss = 0.5 * (criterion(pred_real, valid) + criterion(pred_fake, fake))
        d_loss.backward()
        optimizer_D.step()

        # Train Generator: push the (just updated) discriminator to label
        # generated pairs as real.
        # TODO(review): restoration setups often add a reconstruction term
        # (e.g. L1 against clean_images); none is used here.
        optimizer_G.zero_grad()
        pred_gen = discriminator(torch.cat((damaged_images, generated_batch), dim=1))
        g_loss = criterion(pred_gen, torch.ones_like(pred_gen))
        g_loss.backward()
        optimizer_G.step()

        d_loss_sum += d_loss.item()
        g_loss_sum += g_loss.item()
        n_batches += 1

    if n_batches and (epoch == 0 or (epoch + 1) % log_every == 0):
        print(f"Epoch {epoch + 1}/{num_epochs}  "
              f"D loss: {d_loss_sum / n_batches:.4f}  "
              f"G loss: {g_loss_sum / n_batches:.4f}")



In [24]:
# Generate a batch of images from the last batch left in `damaged_images`
# (bound by the training-loop cell, not defined here).
generated_images = generator(damaged_images)

# Resize generated_images to the spatial size of damaged_images so the two can
# be stacked channel-wise. The size is taken from the tensor instead of being
# hard-coded to 64x64.
generated_images_resized = nn.functional.interpolate(
    generated_images,
    size=tuple(damaged_images.shape[-2:]),
    mode='bilinear',
    align_corners=False,
)

# Measure discriminator's ability to classify fake images (`fake` is the
# all-zeros label tensor left over from the training loop).
fake_loss = criterion(discriminator(torch.cat((damaged_images, generated_images_resized), dim=1)), fake)


In [28]:
# Training loop
# NOTE(review): this loop never calls backward() or an optimizer step, so it
# trains nothing. Its only lasting effect is that `damaged_images`,
# `clean_images`, `valid` and `fake` stay bound to the final batch (on
# `device`), which the next cell reads.
for epoch in range(num_epochs):
    for i, batch in enumerate(data_loader):
        # Configure input
        damaged_images, clean_images = (t.to(device) for t in batch)

        # Adversarial ground truths, matching the discriminator's 1x4x4 output
        # for 64x64 inputs
        rows = clean_images.size(0)
        valid = torch.ones(rows, 1, 4, 4, device=device)
        fake = torch.zeros(rows, 1, 4, 4, device=device)

        # Train Discriminator
        optimizer_D.zero_grad()





In [31]:
# Ensure both tensors have the same batch size. A local name is used so the
# `batch_size` hyperparameter (read by the DataLoader cell) is not overwritten.
current_batch_size = damaged_images.size(0)
generated_images = generated_images[:current_batch_size]

# Resize generated_images to match the spatial dimensions of damaged_images
generated_images_resized = torch.nn.functional.interpolate(
    generated_images,
    size=tuple(damaged_images.shape[-2:]),
    mode='bilinear',
    align_corners=False,
)

# Print out sizes for debugging
print("Damaged Images Size:", damaged_images.size())
print("Generated Images Resized Size:", generated_images_resized.size())

# Channel-wise concatenation needs matching batch and spatial dimensions.
assert generated_images_resized.shape[0] == damaged_images.shape[0]
assert generated_images_resized.shape[2:] == damaged_images.shape[2:]

# Measure discriminator's ability to classify fake images
fake_loss = criterion(discriminator(torch.cat((damaged_images, generated_images_resized), dim=1)), fake)


Damaged Images Size: torch.Size([40, 1, 64, 64])
Generated Images Resized Size: torch.Size([40, 1, 64, 64])
