In [1]:
import torch
from torch.utils.data import DataLoader,Dataset
from torchvision.transforms import transforms
from diffusers import UNet2DConditionModel, DDPMScheduler,AutoencoderKL
from transformers import CLIPTokenizer, CLIPTextModel
from tqdm.auto import tqdm
from PIL import Image
import os
import json


  from .autonotebook import tqdm as notebook_tqdm
Using TensorFlow backend.


In [2]:

# Parameters
batch_size = 16
epochs = 2
learning_rate = 1e-4
image_size = 32
data_path = "./coco"
device = "cuda" if torch.cuda.is_available() else "cpu"


In [3]:

# Tokenizer and Transforms
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
max_length =77


In [4]:

# Dataset Class
class CocoWithAnnotations(Dataset):
    def __init__(self, path, tokenizer, transform, train=True):
        super().__init__()
        self.path = path
        self.data = None
        self.transform = transform
        self.tokenizer = tokenizer
        self.train = train
        if self.data is None:
            self.open_json()

    def open_json(self):
        if self.train:
            print('======================= Loading training annotations =======================')
            with open(f'{self.path}/annotations/captions_train2014.json', 'r') as stream:
                self.data = json.load(stream)
            self.data = self.data['annotations']
        else:
            print('======================= Loading validation annotations =======================')
            with open(f'{self.path}/annotations/captions_val2014.json', 'r') as stream:
                self.data = json.load(stream)
            self.data = self.data['annotations']
        print('======================= ANNOTATIONS LOADED =======================')
    # Dataset __getitem__ method
    def __getitem__(self, index):
        annot = self.data[index]
        if len(str(annot['image_id'])) < 6:
            rem_0l = 6 - len(str(annot['image_id']))
            rem_0 = '0' * rem_0l
            image = self.transform(
                Image.open(f'{self.path}/train2014/COCO_train2014_000000{rem_0 + str(annot["image_id"])}.jpg').convert('RGB')
            )
        else:
            image = self.transform(
                Image.open(f'{self.path}/train2014/COCO_train2014_000000{annot["image_id"]}.jpg').convert('RGB')
            )

        text_emb = self.tokenizer(
            annot['caption'],
            padding='max_length',
            truncation=True,
            max_length=max_length,  # Ensure correct sequence length
            return_tensors="pt"
        )
        return image, text_emb.input_ids.squeeze(0)

    def __len__(self):
        return len(self.data)

# Dataset and DataLoader
dataset = CocoWithAnnotations(data_path, tokenizer, transform, train=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)




In [5]:

# Models
autoencoder = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
unet = UNet2DConditionModel(
    sample_size=image_size,
    in_channels=4,
    out_channels=4,
    down_block_types=(
        'DownBlock2D',
        'CrossAttnDownBlock2D',
        'CrossAttnDownBlock2D'
    ),
    up_block_types=(
        'CrossAttnUpBlock2D',
        'CrossAttnUpBlock2D',
        'UpBlock2D'
    ),
    block_out_channels=(64, 128, 256),
    cross_attention_dim=512  # Ensure this matches CLIP hidden_dim
).to(device)
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device)


In [6]:

# Scheduler
diffusion = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear")

# Optimizer
optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)


In [7]:

# Training Loop
for epoch in range(epochs):
    unet.train()
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}")
    for images, captions in progress_bar:
        images = images.to(device)
        captions = captions.to(device)

        # Encode images to latent space
        latents = autoencoder.encode(images).latent_dist.sample() * 0.18215

        # Forward diffusion
        noise = torch.randn_like(latents).to(device)
        timesteps = torch.randint(0, diffusion.num_train_timesteps, (latents.size(0),), device=device).long()
        noisy_latents = diffusion.add_noise(latents, noise, timesteps)

        # Encode text
        text_embeds = text_encoder(captions).last_hidden_state

        # Predict noise
        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=text_embeds).sample

        # Loss
        loss = torch.nn.functional.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        progress_bar.set_postfix(loss=loss.item())

    print(f"Epoch {epoch + 1} completed with loss: {loss.item():.4f}")

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
Epoch 1/2: 100%|██████████| 25883/25883 [1:28:19<00:00,  4.88it/s, loss=0.00679]


Epoch 1 completed with loss: 0.0068


Epoch 2/2: 100%|██████████| 25883/25883 [1:33:57<00:00,  4.59it/s, loss=0.00188]

Epoch 2 completed with loss: 0.0019





In [8]:
# Save the model
output_dir = "./latent_diffusion_model"
os.makedirs(output_dir, exist_ok=True)
unet.save_pretrained(output_dir)
text_encoder.save_pretrained(output_dir)
autoencoder.save_pretrained(output_dir)

diffusion.save_config(os.path.join(output_dir, "scheduler_config.json"))

print("Model saved successfully!")

Model saved successfully!
