In [16]:
import os
import torch
import numpy as np
import cv2
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import pytorch_lightning as pl
from diffusers import UNet2DModel, DDPMScheduler
import matplotlib.pyplot as plt
from tqdm import tqdm

# Custom dataset class with pre-computed text images
class CustomDataset(Dataset):
    """Pairs each image in ``image_dir`` with the same-named image in ``text_image_dir``.

    Both images are loaded as single-channel ("L") PIL images and passed
    through the same ``transform``.

    Args:
        image_dir: directory containing the target images.
        text_image_dir: directory containing the pre-computed condition (text)
            images; every file must share its name with a file in ``image_dir``.
        transform: optional callable applied to both images.
        num_images: keep at most this many files (``None`` keeps all of them).
    """

    def __init__(self, image_dir, text_image_dir, transform=None, num_images=2000):
        self.image_dir = image_dir
        self.text_image_dir = text_image_dir
        # os.listdir order is arbitrary and filesystem-dependent; sort so the
        # first-`num_images` subset is identical on every run. Skip
        # sub-directories and hidden entries (e.g. .ipynb_checkpoints,
        # .DS_Store), which PIL cannot open as images.
        filenames = sorted(
            name for name in os.listdir(image_dir)
            if not name.startswith('.') and os.path.isfile(os.path.join(image_dir, name))
        )
        self.filenames = filenames[:num_images]
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(image, text_image)`` for the ``idx``-th file, transformed if a transform is set."""
        filename = self.filenames[idx]
        # `convert` returns a new, fully loaded image, so the source file can be
        # closed immediately by the context manager.
        with Image.open(os.path.join(self.image_dir, filename)) as img:
            image = img.convert("L")  # Convert to grayscale
        with Image.open(os.path.join(self.text_image_dir, filename)) as img:
            text_image = img.convert("L")  # Convert to grayscale

        if self.transform:
            image = self.transform(image)
            text_image = self.transform(text_image)

        return image, text_image

# Define the transformations
transform = transforms.Compose([
    transforms.Resize((64, 2048)),  # (height, width); must match the UNet's sample_size
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # Normalize for grayscale images: [0, 1] -> [-1, 1]
])

# Define the dataset and dataloader.
# The directories can be overridden via environment variables instead of
# editing the notebook; the defaults are the original paths.
image_dir = os.environ.get('IMAGE_DIR', '/images')  # target image directory
text_image_dir = os.environ.get('TEXT_IMAGE_DIR', '/transcriptions')  # transcription (condition) image directory
dataset = CustomDataset(image_dir=image_dir, text_image_dir=text_image_dir, transform=transform, num_images=2000)
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=8,
    # Keep the worker processes alive between epochs instead of re-spawning
    # all 8 of them at the start of every epoch.
    persistent_workers=True,
)

In [17]:
class SimpleConditionalDDPMWithImage(pl.LightningModule):
    """DDPM noise-prediction model conditioned on a rendered text image.

    The noisy image and the condition image are concatenated along the channel
    axis and fed to a 2-channel-in / 1-channel-out ``UNet2DModel``. The model is
    trained with MSE to predict the noise that ``DDPMScheduler.add_noise`` mixed
    into the clean image.

    Args:
        lr: Adam learning rate.
        num_train_timesteps: length of the DDPM noise schedule.
        sample_size: (height, width) of the images; must match the spatial size
            produced by the dataset transform.
    """

    def __init__(self, lr=1e-4, num_train_timesteps=1000, sample_size=(64, 2048)):
        super().__init__()
        # Record constructor args so checkpoints can be reloaded with the same config.
        self.save_hyperparameters()
        self.model = UNet2DModel(
            sample_size=sample_size,
            in_channels=2,  # noisy image + condition (text) image
            out_channels=1,  # predicted noise for the single image channel
            layers_per_block=2,
            block_out_channels=(64, 128, 256, 512),
            down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D"),
            up_block_types=("UpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"),
        )
        self.scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps)
        self.criterion = torch.nn.MSELoss()

    def forward(self, x, t, text_image):
        """Predict the noise in ``x`` at timestep(s) ``t``, conditioned on ``text_image``.

        ``x`` and ``text_image`` are concatenated along dim 1 (channels), so they
        must share batch size and spatial size.
        """
        return self.model(torch.cat((x, text_image), dim=1), t).sample

    def training_step(self, batch, batch_idx):
        """Noise the clean images at random timesteps and regress the added noise."""
        # Lightning has already moved the batch to self.device, so no explicit
        # .to() calls are needed here.
        images, text_images = batch
        t = torch.randint(
            0, self.scheduler.config.num_train_timesteps, (images.size(0),),
            device=self.device, dtype=torch.long,
        )
        noise = torch.randn_like(images)  # same device and dtype as `images`
        noisy_images = self.scheduler.add_noise(original_samples=images, noise=noise, timesteps=t)
        predicted_noise = self(noisy_images, t, text_images)
        loss = self.criterion(predicted_noise, noise)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

# Initialize the model with image-based condition
model_with_image = SimpleConditionalDDPMWithImage()

# Define the trainer. This trains the freshly initialised model above from
# scratch; fit() is not given a checkpoint to resume from.
trainer_with_image = pl.Trainer(
    precision="16-mixed",  # fp16 automatic mixed precision (preferred over the legacy `precision=16`)
    max_epochs=50,
    accelerator='gpu',
    devices=1,
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [20]:
# Train the model with image-based condition
trainer_with_image.fit(
    model=model_with_image,
    train_dataloaders=dataloader,
)

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:652: Checkpoint directory /lightning_logs/version_33/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type        | Params | Mode
-------------------------------------------------
0 | model     | UNet2DModel | 56.6 M | eval
1 | criterion | MSELoss     | 0      | eval
-------------------------------------------------
56.6 M    Trainable params
0         Non-trainable params
56.6 M    Total params
226.291   Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.


In [None]:
import random
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm

# Function to generate conditional images from noise using a random text image from training data
def generate_conditional_images(model, scheduler, dataset, num_images=8, device='cuda', seed=None):
    """Sample images conditioned on random text images from ``dataset`` and plot them.

    For each of ``num_images`` samples, a random ``(image, text_image)`` pair is
    drawn from ``dataset``. The full reverse diffusion chain is then run with
    ``scheduler``, starting from Gaussian noise and conditioning every step on
    the text image. Each sample is shown in three stacked panels: the generated
    image, the ground-truth image and the condition image.

    Args:
        model: callable ``model(sample, t, text_image)`` returning predicted
            noise shaped like ``sample``; must support ``.to`` and ``.eval``.
        scheduler: diffusers-style scheduler exposing
            ``config.num_train_timesteps`` and
            ``step(model_output, timestep, sample, generator=...)``.
        dataset: indexable, sized collection of ``(image, text_image)`` tensor
            pairs (assumed single-channel ``(1, H, W)``).
        num_images: number of conditional samples to generate (must be >= 1).
        device: device on which sampling runs.
        seed: optional seed. When given, the dataset picks and all sampling noise
            are reproducible. When ``None``, the global ``random`` / ``torch``
            RNGs are used.

    Raises:
        ValueError: if ``num_images`` < 1 or ``dataset`` is empty.
    """
    # tqdm.auto renders nested bars as notebook widgets instead of flooding the
    # cell output with terminal escape codes.
    from tqdm.auto import tqdm

    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")
    if len(dataset) == 0:
        raise ValueError("dataset is empty")

    rng = random.Random(seed) if seed is not None else random
    generator = torch.Generator(device=device).manual_seed(seed) if seed is not None else None

    model.to(device)  # Ensure model is on the correct device
    model.eval()
    num_steps = scheduler.config.num_train_timesteps
    # 3 panels per sample at ~1.25in each (7.5in for 2 samples, as before),
    # so the figure grows with num_images instead of squashing the panels.
    fig, axs = plt.subplots(3 * num_images, 1, figsize=(20, 3.75 * num_images))

    with torch.no_grad():
        for i in tqdm(range(num_images), desc="samples"):
            image, text_image = dataset[rng.randrange(len(dataset))]
            condition = text_image.to(device).unsqueeze(0)  # add batch dimension
            # Start from pure noise with the condition image's spatial size.
            sample = torch.randn((1, 1, *condition.shape[-2:]), generator=generator, device=device)
            for t in tqdm(range(num_steps - 1, -1, -1), desc="denoising", leave=False):
                t_tensor = torch.tensor([t], device=device, dtype=torch.long)
                predicted_noise = model(sample, t_tensor, condition)
                # Scheduler.step takes a plain int timestep; the tensor form is only for the UNet.
                sample = scheduler.step(
                    model_output=predicted_noise, timestep=t, sample=sample, generator=generator
                ).prev_sample

            panels = (
                (sample, "Generated Image"),
                (image, "Ground Truth Image"),
                (condition, "Condition Image"),
            )
            for j, (tensor, title) in enumerate(panels):
                ax = axs[3 * i + j]
                ax.imshow(tensor.squeeze().cpu().numpy(), cmap='gray')
                ax.set_title(title)
                ax.axis('off')

    plt.tight_layout()
    plt.show()

# Generate and display conditional images using random text images from the training data
generate_conditional_images(
    model=model_with_image,
    scheduler=model_with_image.scheduler,
    dataset=dataset,
    num_images=2,
    device='cuda',
)


  0%|          | 0/2 [00:00<?, ?it/s][A

  0%|          | 0/1000 [00:00<?, ?it/s][A[A

  0%|          | 5/1000 [00:00<00:23, 42.02it/s][A[A

  1%|          | 10/1000 [00:00<00:23, 42.68it/s][A[A

  2%|▏         | 15/1000 [00:00<00:22, 43.01it/s][A[A

  2%|▏         | 20/1000 [00:00<00:22, 43.22it/s][A[A

  2%|▎         | 25/1000 [00:00<00:22, 43.20it/s][A[A

  3%|▎         | 30/1000 [00:00<00:22, 42.73it/s][A[A

  4%|▎         | 35/1000 [00:00<00:22, 42.86it/s][A[A

  4%|▍         | 40/1000 [00:00<00:22, 43.03it/s][A[A

  4%|▍         | 45/1000 [00:01<00:22, 42.93it/s][A[A

  5%|▌         | 50/1000 [00:01<00:22, 42.98it/s][A[A

  6%|▌         | 55/1000 [00:01<00:21, 43.10it/s][A[A

  6%|▌         | 60/1000 [00:01<00:21, 43.13it/s][A[A

  6%|▋         | 65/1000 [00:01<00:21, 43.18it/s][A[A

  7%|▋         | 70/1000 [00:01<00:21, 43.13it/s][A[A

  8%|▊         | 75/1000 [00:01<00:21, 43.17it/s][A[A

  8%|▊         | 80/1000 [00:01<00:21, 43.26it/s][A[A