In [1]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class ImagePairDataset(Dataset):
    def __init__(self, comp_dir, lg_dir, transform=None):
        self.comp_dir = comp_dir
        self.lg_dir = lg_dir
        self.transform = transform
        self.image_names = [f for f in os.listdir(comp_dir) if os.path.isfile(os.path.join(comp_dir, f))]

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        comp_image_name = self.image_names[idx]
        lg_image_name = comp_image_name  # Assuming the paired images have the same name

        comp_image = Image.open(os.path.join(self.comp_dir, comp_image_name)).convert("L")
        lg_image = Image.open(os.path.join(self.lg_dir, lg_image_name)).convert("L")

        if self.transform:
            comp_image = self.transform(comp_image)
            lg_image = self.transform(lg_image)

        return comp_image, lg_image

transform = transforms.Compose([
    transforms.Resize((64, 2048)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # Normalize for grayscale images
])

comp_dir = '/Selected_Comp'
lg_dir = '/Selected_LG'
dataset = ImagePairDataset(comp_dir, lg_dir, transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)


In [2]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np

def create_window(window_size, channel):
    _1D_window = torch.tensor(np.hanning(window_size), dtype=torch.float32)
    _2D_window = _1D_window.unsqueeze(1) * _1D_window.unsqueeze(0)
    _2D_window = _2D_window.unsqueeze(0).unsqueeze(0)
    window = torch.tensor(_2D_window.expand(channel, 1, window_size, window_size).contiguous(), dtype=torch.float32)
    return window

class SSIMLoss(nn.Module):
    def __init__(self, window_size=11, channel=1):
        super(SSIMLoss, self).__init__()
        self.window_size = window_size
        self.channel = channel
        self.window = create_window(window_size, channel).to('cuda')

    def forward(self, img1, img2):
        mu1 = F.conv2d(img1, self.window.type_as(img1), padding=self.window_size//2, groups=self.channel)
        mu2 = F.conv2d(img2, self.window.type_as(img2), padding=self.window_size//2, groups=self.channel)
        
        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2
        
        sigma1_sq = F.conv2d(img1 * img1, self.window.type_as(img1), padding=self.window_size//2, groups=self.channel) - mu1_sq
        sigma2_sq = F.conv2d(img2 * img2, self.window.type_as(img2), padding=self.window_size//2, groups=self.channel) - mu2_sq
        sigma12 = F.conv2d(img1 * img2, self.window.type_as(img1), padding=self.window_size//2, groups=self.channel) - mu1_mu2
        
        C1 = 0.01 ** 2
        C2 = 0.03 ** 2
        
        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
        return torch.clamp((1 - ssim_map) / 2, 0, 1).mean()


In [3]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from diffusers import UNet2DModel, DDPMScheduler

class ConditionalDDPM(pl.LightningModule):
    def __init__(self):
        super(ConditionalDDPM, self).__init__()
        self.model = UNet2DModel(
            sample_size=(64, 2048),
            in_channels=2,  # Two channels: one for the noisy image, one for the condition image
            out_channels=1,  # Single channel output for grayscale images
            layers_per_block=4,
            block_out_channels=(64, 128, 256, 512, 1024),
            down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D"),
            up_block_types=("UpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D")
        )
        self.scheduler = DDPMScheduler(num_train_timesteps=1000)
        self.criterion_mse = nn.MSELoss()
        self.criterion_ssim = SSIMLoss()

    def forward(self, x, t, condition):
        x = torch.cat([x, condition], dim=1)  # Concatenate the noisy image and the condition image along the channel dimension
        return self.model(x, t).sample

    def training_step(self, batch, batch_idx):
        comp_images, lg_images = batch
        comp_images = comp_images.to(self.device, dtype=torch.float32)
        lg_images = lg_images.to(self.device, dtype=torch.float32)

        t = torch.randint(0, self.scheduler.config.num_train_timesteps, (comp_images.size(0),), device=self.device).long()
        noise = torch.randn_like(comp_images).to(self.device, dtype=torch.float32)
        noisy_images = self.scheduler.add_noise(original_samples=comp_images, noise=noise, timesteps=t)

        predicted_noise = self(noisy_images, t, comp_images)
        mse_loss = self.criterion_mse(predicted_noise, noise)

        # Generate the denoised images
        denoised_images = noisy_images - predicted_noise

        # Calculate SSIM loss on the denoised images and the target images
        ssim_loss = self.criterion_ssim(denoised_images, lg_images)
        
        loss = mse_loss + ssim_loss
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)


In [4]:
from pytorch_lightning import Trainer

trainer = Trainer(
    accumulate_grad_batches=4,  # Gradient accumulation
    precision=16,  # Mixed precision
    max_epochs=100,  # Increased number of epochs
    accelerator='gpu',
    devices=1
)

model = ConditionalDDPM()

/usr/local/lib/python3.11/dist-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
  window = torch.tensor(_2D_window.expand(channel, 1, window_size, window_size).contiguous(), dtype=torch.float32)


In [5]:
trainer.fit(model, dataloader)

You are using a CUDA device ('NVIDIA L40S') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: /lightning_logs
2024-07-12 12:28:04.196995: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-12 12:28:04.213232: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-12 12:28:04.213261: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one h

Training: |          | 0/? [00:00<?, ?it/s]

OSError: [Errno 28] No space left on device

In [None]:
import random
import matplotlib.pyplot as plt
from tqdm import tqdm
# Select a random image from the dataset for conditioning
random_idx = random.randint(0, len(dataset) - 1)
comp_image, _ = dataset[random_idx]
comp_image = comp_image.unsqueeze(0).to('cuda')  # Add batch dimension and move to GPU

model.eval()
with torch.no_grad():
    num_images = 1  # Generate one image conditioned on the input
    noise = torch.randn(num_images, 1, 64, 2048).to('cuda')  # Initial noise
    for t in tqdm(reversed(range(model.scheduler.config.num_train_timesteps))):
        model_output = model(noise, t, comp_image)
        noise = model.scheduler.step(model_output, t, noise).prev_sample

    generated_image = (noise + 1) / 2  # Convert [-1, 1] to [0, 1]

# Plot the input image
plt.figure(figsize=(20, 10))

plt.subplot(2, 1, 1)
plt.imshow(comp_image.squeeze().cpu(), cmap='gray')
plt.title('Input Image (Computer Font)')
plt.axis("off")

# Plot the generated image
plt.subplot(2, 1, 2)
plt.imshow(generated_image.squeeze().cpu(), cmap='gray')
plt.title('Generated Image (LG Style)')
plt.axis("off")

plt.show()
