In [1]:
import torch
from torch.utils.data import DataLoader,Dataset
from torchvision.transforms import transforms
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
from transformers import CLIPTokenizer, CLIPTextModel
from tqdm.auto import tqdm
from PIL import Image
import os
import json


  from .autonotebook import tqdm as notebook_tqdm
Using TensorFlow backend.


In [2]:

# Parameters
batch_size = 8  # Reduced batch size for stability
epochs = 2
learning_rate = 1e-4
image_size = 256  # Increased image size
latent_size = image_size // 8  # Calculate latent size (512 -> 64)
data_path = "./coco"
device = "cuda" if torch.cuda.is_available() else "cpu"


In [3]:

# Tokenizer and Transforms
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
max_length = 77


In [4]:

class CocoWithAnnotations(Dataset):
    def __init__(self, path, tokenizer, transform, train=True):
        super().__init__()
        self.path = path
        self.data = None
        self.transform = transform
        self.tokenizer = tokenizer
        self.train = train
        if self.data is None:
            self.open_json()

    def open_json(self):
        split = "train" if self.train else "val"
        print(f'Loading {split} annotations...')
        with open(f'{self.path}/annotations/captions_{split}2014.json', 'r') as stream:
            self.data = json.load(stream)['annotations']
        print('Annotations loaded')

    def __getitem__(self, index):
        annot = self.data[index]
        img_id = str(annot['image_id']).zfill(12)
        split = "train" if self.train else "val"
        
        try:
            image_path = f'{self.path}/{split}2014/COCO_{split}2014_{img_id}.jpg'
            image = Image.open(image_path).convert('RGB')
            image = self.transform(image)
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            # Return a default item in case of error
            return torch.zeros((3, image_size, image_size)), torch.zeros((max_length,))

        text_emb = self.tokenizer(
            annot['caption'],
            padding='max_length',
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )
        return image, text_emb.input_ids.squeeze(0)

    def __len__(self):
        return len(self.data)

In [5]:
# Models with proper configuration
unet = UNet2DConditionModel(
    sample_size=latent_size,  # Use calculated latent size
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(128, 256, 512, 512),
    down_block_types=(
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
    ),
    cross_attention_dim=768,  # Match CLIP's hidden dimension
).to(device)

# Load pretrained models
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
autoencoder = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
diffusion = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear")

# Freeze VAE and text encoder
for param in autoencoder.parameters():
    param.requires_grad = False
for param in text_encoder.parameters():
    param.requires_grad = False

# Dataset and DataLoader with error handling
def collate_fn(batch):
    # Filter out any None values from failed loads
    batch = [(img, cap) for img, cap in batch if img is not None and cap is not None]
    if not batch:
        return None
    images, captions = zip(*batch)
    return torch.stack(images), torch.stack(captions)

dataset = CocoWithAnnotations(data_path, tokenizer, transform, train=True)
dataloader = DataLoader(
    dataset, 
    batch_size=batch_size, 
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True
)



Loading train annotations...
Annotations loaded


In [None]:
# Optimizer with gradient clipping
optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)
scaler = torch.cuda.amp.GradScaler()  # For mixed precision training

# Training loop with improved error handling and mixed precision
for epoch in range(epochs):
    unet.train()
    epoch_loss = 0
    valid_batches = 0
    
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}")
    for batch in progress_bar:
        if batch is None:
            continue
            
        images, captions = batch
        images = images.to(device)
        captions = captions.to(device)

        with torch.cuda.amp.autocast():
            # Encode images
            with torch.no_grad():
                latents = autoencoder.encode(images).latent_dist.sample()
                latents = latents * 0.18215

            # Noise augmentation
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, diffusion.config.num_train_timesteps, (latents.shape[0],), device=device)
            noisy_latents = diffusion.add_noise(latents, noise, timesteps)

            # Get text embeddings
            with torch.no_grad():
                encoder_hidden_states = text_encoder(captions)[0]

            # Predict noise
            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="mean")

        # Backward pass with gradient scaling
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        epoch_loss += loss.item()
        valid_batches += 1
        progress_bar.set_postfix({"loss": loss.item()})

    avg_loss = epoch_loss / valid_batches
    print(f"Epoch {epoch + 1} completed with average loss: {avg_loss:.4f}")

    # Save checkpoint after each epoch
    checkpoint_dir = f"./ldm_checkpoints/epoch_{epoch + 1}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    unet.save_pretrained(os.path.join(checkpoint_dir, "unet"))
    torch.save({
        'epoch': epoch,
        'model_state_dict': unet.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
    }, os.path.join(checkpoint_dir, "training_state.pth"))

print("Training completed!")

  scaler = torch.cuda.amp.GradScaler()  # For mixed precision training
Epoch 1/2:   0%|          | 0/51765 [00:00<?, ?it/s]