In [1]:
import torch
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
import os
from torchvision import transforms
from PIL import Image
import pytorch_lightning as pl
from diffusers import UNet2DModel, DDPMScheduler
import matplotlib.pyplot as plt
from tqdm import tqdm

# Load the train split from the Hugging Face dataset
dataset = load_dataset('ta4tsering/Lhasa_kanjur_transcription_datasets', split='train')

# Extract filenames and transcriptions
filenames = dataset['filename']
transcriptions = dataset['label']

# Load pre-trained model and tokenizer for text embeddings
tokenizer = AutoTokenizer.from_pretrained('openpecha/tibetan_RoBERTa_S_e6')
model_roberta = AutoModel.from_pretrained('openpecha/tibetan_RoBERTa_S_e6').to('cuda')

# Function to batch convert texts to vectors using GPU
def texts_to_vectors(texts, batch_size=32):
    vectors = []
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i+batch_size]
        inputs = tokenizer(batch_texts, return_tensors='pt', padding=True, truncation=True).to('cuda')
        with torch.no_grad():
            outputs = model_roberta(**inputs)
        batch_vectors = outputs.last_hidden_state.mean(dim=1).detach().cpu().numpy()
        vectors.extend(batch_vectors)
    return vectors

# Pre-compute text embeddings in batches using GPU
text_vectors = texts_to_vectors(transcriptions)

# Custom dataset class with pre-computed text embeddings
class CustomDataset(Dataset):
    def __init__(self, image_dir, filenames, text_vectors, transform=None):
        self.image_dir = image_dir
        self.filenames = filenames
        self.text_vectors = text_vectors
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.filenames[idx])
        image = Image.open(img_name).convert("L")  # Convert to grayscale
        if self.transform:
            image = self.transform(image)
        text_vector = self.text_vectors[idx]
        return image, text_vector

# Define the transformations
transform = transforms.Compose([
    transforms.Resize((64, 2048)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # Normalize for grayscale images
])

# Define the dataset and dataloader
image_dir = '/local_dir/Train_Images'  # Adjust to your local directory
dataset = CustomDataset(image_dir=image_dir, filenames=filenames, text_vectors=text_vectors, transform=transform)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

2024-07-22 12:33:49.232252: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-07-22 12:33:49.249226: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-22 12:33:49.269183: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-22 12:33:49.275198: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-07-22 12:33:49.290110: I tensorflow/core/platform/cpu_feature_guar

In [4]:
import torch.nn as nn
# Define the adjusted conditional DDPM model
class SimpleConditionalDDPM(pl.LightningModule):
    def __init__(self):
        super(SimpleConditionalDDPM, self).__init__()
        self.model = UNet2DModel(
            sample_size=(64, 2048),
            in_channels=2,  # Update to accept 2 channels (image + text)
            out_channels=1,
            layers_per_block=2,
            block_out_channels=(64, 128, 256, 512),
            down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D"),
            up_block_types=("UpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D")
        )
        self.text_embedding = nn.Linear(768, 64 * 2048)  # Adjust dimensions as needed
        self.scheduler = DDPMScheduler(num_train_timesteps=1000)
        self.criterion = torch.nn.MSELoss()

    def forward(self, x, t, text_vector):
        # Convert text_vector to match image dimensions
        text_embedding = self.text_embedding(text_vector).view(-1, 1, 64, 2048)
        # Concatenate image and text embeddings
        x = torch.cat((x, text_embedding), dim=1)
        return self.model(x, t).sample

    def training_step(self, batch, batch_idx):
        images, text_vectors = batch
        images = images.to(self.device)
        text_vectors = torch.tensor(text_vectors).to(self.device)
        t = torch.randint(0, self.scheduler.config.num_train_timesteps, (images.size(0),), device=self.device).long()
        noise = torch.randn_like(images).to(self.device)
        noisy_images = self.scheduler.add_noise(original_samples=images, noise=noise, timesteps=t)

        predicted_noise = self(noisy_images, t, text_vectors)
        loss = self.criterion(predicted_noise, noise)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

# Initialize the conditional model
model = SimpleConditionalDDPM()

# Define the trainer
trainer = pl.Trainer(
    accumulate_grad_batches=4,  # Gradient accumulation
    precision=16,  # Mixed precision
    max_epochs=100,
    accelerator='gpu',
    devices=1
)

/usr/local/lib/python3.11/dist-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
# Train the model
trainer.fit(model, dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type        | Params | Mode
------------------------------------------------------
0 | model          | UNet2DModel | 56.6 M | eval
1 | text_embedding | Linear      | 100 M  | eval
2 | criterion      | MSELoss     | 0      | eval
------------------------------------------------------
157 M     Trainable params
0         Non-trainable params
157 M     Total params
629.469   Total estimated model params size (MB)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment va

Training: |          | 0/? [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [None]:
import matplotlib.pyplot as plt
import random
from tqdm import tqdm

# Function to generate conditional images from noise using a random text vector from training data
def generate_conditional_images(model, scheduler, dataset, num_images=8, device='cuda'):
    model.eval()
    fig, axs = plt.subplots(2 * num_images, 1, figsize=(20, 5))
    with torch.no_grad():
        for i in tqdm(range(num_images)):
            # Select a random image and text vector pair from the dataset
            random_idx = 0
            image, random_text_vector = dataset[random_idx]
            random_text_vector = torch.tensor(random_text_vector).to(device)
            noise = torch.randn((1, 1, 64, 2048)).to(device)
            for t in tqdm(range(scheduler.config.num_train_timesteps - 1, -1, -1), leave=False):
                t_tensor = torch.tensor([t], device=device).long()
                predicted_noise = model(noise, t_tensor, random_text_vector)
                noise = scheduler.step(model_output=predicted_noise, timestep=t_tensor, sample=noise).prev_sample

            # Display the generated image
            axs[2 * i].imshow(noise.squeeze().cpu().numpy(), cmap='gray')
            axs[2 * i].set_title("Generated Image")
            axs[2 * i].axis('off')

            # Display the ground truth image
            axs[2 * i + 1].imshow(image.permute(1, 2, 0).cpu().numpy(), cmap='gray')
            axs[2 * i + 1].set_title("Ground Truth Image")
            axs[2 * i + 1].axis('off')

    plt.tight_layout()
    plt.show()

# Generate and display conditional images using random text vectors from the training data
generate_conditional_images(model, model.scheduler, dataset, num_images=4, device='cuda')


  0%|          | 0/4 [00:00<?, ?it/s][A

  0%|          | 0/1000 [00:00<?, ?it/s][A[A

  0%|          | 5/1000 [00:00<00:22, 44.17it/s][A[A

  1%|          | 10/1000 [00:00<00:22, 43.35it/s][A[A

  2%|▏         | 15/1000 [00:00<00:22, 42.91it/s][A[A

  2%|▏         | 20/1000 [00:00<00:22, 42.81it/s][A[A

  2%|▎         | 25/1000 [00:00<00:22, 42.82it/s][A[A

  3%|▎         | 30/1000 [00:00<00:22, 42.72it/s][A[A

  4%|▎         | 35/1000 [00:00<00:22, 42.70it/s][A[A

  4%|▍         | 40/1000 [00:00<00:22, 42.71it/s][A[A

  4%|▍         | 45/1000 [00:01<00:22, 42.69it/s][A[A

  5%|▌         | 50/1000 [00:01<00:22, 42.69it/s][A[A

  6%|▌         | 55/1000 [00:01<00:22, 42.43it/s][A[A

  6%|▌         | 60/1000 [00:01<00:22, 42.58it/s][A[A

  6%|▋         | 65/1000 [00:01<00:21, 42.62it/s][A[A

  7%|▋         | 70/1000 [00:01<00:21, 42.64it/s][A[A

  8%|▊         | 75/1000 [00:01<00:21, 42.62it/s][A[A

  8%|▊         | 80/1000 [00:01<00:21, 42.63it/s][A[A