In [None]:
!pip install torch torchvision numpy pillow tqdm matplotlib

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from PIL import Image
import os

# Generator (U-Net)
class Generator(nn.Module):
    def __init__(self, input_channels=3, output_channels=3):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, output_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

# Discriminator (PatchGAN)
class Discriminator(nn.Module):
    def __init__(self, input_channels=3):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 1, kernel_size=4, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

In [3]:
class SunsetDataset(Dataset):
    def __init__(self, sunset_path, daytime_path, transform):
        self.sunset_images = sorted(os.listdir(sunset_path))
        self.daytime_images = sorted(os.listdir(daytime_path))
        self.sunset_path = sunset_path
        self.daytime_path = daytime_path
        self.transform = transform

    def __len__(self):
        return min(len(self.sunset_images), len(self.daytime_images))

    def __getitem__(self, idx):
        sunset_img = Image.open(os.path.join(self.sunset_path, self.sunset_images[idx])).convert("RGB")
        daytime_img = Image.open(os.path.join(self.daytime_path, self.daytime_images[idx])).convert("RGB")

        sunset_img = self.transform(sunset_img)
        daytime_img = self.transform(daytime_img)

        return sunset_img, daytime_img

# Define Transformations
IMG_SIZE = 256
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Load dataset
dataset = SunsetDataset("sunrise", "daytime", transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

In [6]:
# Losses
criterion_gan = nn.BCELoss()
criterion_cycle = nn.L1Loss()

# Initialize Models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator().to(device)
discriminator = Discriminator().to(device)

optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [9]:
def train_cycleGAN(generator, discriminator, dataloader, epochs):
    for epoch in range(epochs):
        for i, (sunset, daytime) in enumerate(dataloader):
            sunset, daytime = sunset.to(device), daytime.to(device)

            # Train Generator
            optimizer_g.zero_grad()
            fake_daytime = generator(sunset)
            d_output = discriminator(fake_daytime)
            real_labels = torch.ones_like(d_output, device=device)
            fake_labels = torch.zeros_like(d_output, device=device)
            g_loss = criterion_gan(d_output, real_labels) + 100 * criterion_cycle(fake_daytime, daytime)
            g_loss.backward()
            optimizer_g.step()

            # Train Discriminator
            optimizer_d.zero_grad()
            real_output = discriminator(daytime)
            fake_output = discriminator(fake_daytime.detach())
            d_loss = (criterion_gan(real_output, real_labels) + criterion_gan(fake_output, fake_labels)) / 2
            d_loss.backward()
            optimizer_d.step()

            if i % 100 == 0:
                print(f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")

    torch.save(generator.state_dict(), "generator5.pth")
    torch.save(discriminator.state_dict(), "discriminator5.pth")

# Train CycleGAN
train_cycleGAN(generator, discriminator, dataloader, epochs=5)

Epoch [0/5], Step [0/115], D Loss: 0.1247502863407135, G Loss: 50.19077682495117
Epoch [0/5], Step [100/115], D Loss: 0.22471849620342255, G Loss: 50.95927810668945
Epoch [1/5], Step [0/115], D Loss: 0.18064631521701813, G Loss: 49.57779312133789
Epoch [1/5], Step [100/115], D Loss: 0.2735564112663269, G Loss: 49.95952606201172
Epoch [2/5], Step [0/115], D Loss: 0.24549731612205505, G Loss: 48.38978576660156
Epoch [2/5], Step [100/115], D Loss: 0.29486003518104553, G Loss: 42.84647750854492
Epoch [3/5], Step [0/115], D Loss: 0.17149807512760162, G Loss: 47.124305725097656
Epoch [3/5], Step [100/115], D Loss: 0.4019019901752472, G Loss: 45.79384231567383
Epoch [4/5], Step [0/115], D Loss: 0.2299571931362152, G Loss: 47.204559326171875
Epoch [4/5], Step [100/115], D Loss: 0.12288659065961838, G Loss: 52.82942199707031


In [10]:
def infer(generator, image_path, transform):
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)
    generator.eval()
    with torch.no_grad():
        corrected_image = generator(image).to(device)
    save_image(corrected_image, "corrected_image.png")

generator.load_state_dict(torch.load("generator5.pth"))
discriminator.load_state_dict(torch.load("discriminator5.pth"))

# Run inference
infer(generator, "Image2.jpg", transform)

  generator.load_state_dict(torch.load("generator5.pth"))
  discriminator.load_state_dict(torch.load("discriminator5.pth"))
