# Run the code in this notebook to finetune the model at a certain epoch

In [1]:
import torch
from torch.utils.data import DataLoader,Dataset
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
from transformers import CLIPTokenizer, CLIPTextModel
from tqdm.auto import tqdm
import os
import json
from PIL import Image
from torchvision.transforms import transforms


In [2]:

# Parameters
batch_size = 8
EPOCHS = 4  # SET IT ACCORDINGLY
learning_rate = 1e-4
image_size = 128
latent_size = image_size // 16
data_path = "/teamspace/studios/this_studio/coco"
checkpoint_to_resume = "./ldm_checkpoints/epoch_6"  # Specify your checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer and Transforms
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
max_length = 77



In [3]:
# Dataset class (same as your training script)
class CocoWithAnnotations(Dataset):
    def __init__(self, path, tokenizer, transform, train=True):
        super().__init__()
        self.path = path
        self.data = None
        self.transform = transform
        self.tokenizer = tokenizer
        self.train = train
        if self.data is None:
            self.open_json()

    def open_json(self):
        split = "train" if self.train else "val"
        with open(f'{self.path}/annotations/captions_{split}2014.json', 'r') as stream:
            self.data = json.load(stream)['annotations']

    def __getitem__(self, index):
        annot = self.data[index]
        img_id = str(annot['image_id']).zfill(12)
        split = "train" if self.train else "val"
        image_path = f'{self.path}/{split}2014/COCO_{split}2014_{img_id}.jpg'
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)
        text_emb = self.tokenizer(
            annot['caption'],
            padding='max_length',
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )
        return image, text_emb.input_ids.squeeze(0)

    def __len__(self):
        return len(self.data)

# DataLoader
def collate_fn(batch):
    batch = [(img, cap) for img, cap in batch if img is not None and cap is not None]
    if not batch:
        return None
    images, captions = zip(*batch)
    return torch.stack(images), torch.stack(captions)

dataset = CocoWithAnnotations(data_path, tokenizer, transform, train=True)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True
)

In [4]:
# Models
unet = UNet2DConditionModel.from_pretrained(os.path.join(checkpoint_to_resume, "unet")).to(device)
autoencoder = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
diffusion = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear")

# Freeze VAE and text encoder
for param in autoencoder.parameters():
    param.requires_grad = False
for param in text_encoder.parameters():
    param.requires_grad = False

# Optimizer
optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)
scaler = torch.cuda.amp.GradScaler()

# Load checkpoint
checkpoint = torch.load(os.path.join(checkpoint_to_resume, "training_state.pth"))
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
START_EPOCH = checkpoint['epoch'] + 1  # Start from the next epoch

# Training Loop
for epoch in range(START_EPOCH, START_EPOCH + EPOCHS):
    unet.train()
    epoch_loss = 0
    valid_batches = 0

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{START_EPOCH + EPOCHS}")
    for step, batch in enumerate(progress_bar):
        if batch is None:
            print(f"[WARN] Batch {step}: NoneType batch skipped.")
            continue

        images, captions = batch
        images = images.to(device, non_blocking=True)
        captions = captions.to(device, non_blocking=True)

        # Sanity check on image inputs
        if not torch.isfinite(images).all():
            print(f"[ERROR] Non-finite values in images at step {step}, skipping batch.")
            continue

        with torch.cuda.amp.autocast():
            try:
                with torch.no_grad():
                    latents = autoencoder.encode(images).latent_dist.sample()
                    latents = latents * 0.18215

                if not torch.isfinite(latents).all():
                    print(f"[ERROR] NaN/Inf in latents at step {step}, skipping batch.")
                    continue

                noise = torch.randn_like(latents)
                timesteps = torch.randint(
                    0, diffusion.config.num_train_timesteps, 
                    (latents.shape[0],), device=device
                )

                noisy_latents = diffusion.add_noise(latents, noise, timesteps)

                if not torch.isfinite(noisy_latents).all():
                    print(f"[ERROR] NaN/Inf in noisy_latents at step {step}, skipping batch.")
                    continue

                with torch.no_grad():
                    encoder_hidden_states = text_encoder(captions)[0]

                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            except Exception as e:
                print(f"[EXCEPTION] Error at step {step}: {e}")
                continue

        # Force loss to run in full precision (helps avoid AMP overflow)
        with torch.cuda.amp.autocast(enabled=False):
            loss = torch.nn.functional.mse_loss(
                noise_pred.float(), noise.float(), reduction="mean"
            )

        # Detect NaN loss early
        if not torch.isfinite(loss):
            print(f"[WARN] Non-finite loss detected (step {step}), skipping optimizer step.")
            optimizer.zero_grad(set_to_none=True)
            scaler = torch.cuda.amp.GradScaler()
            continue

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(unet.parameters(), 0.5)

        try:
            scaler.step(optimizer)
            scaler.update()
        except Exception as e:
            print(f"[EXCEPTION] Optimizer/Scaler step failed at step {step}: {e}")
            optimizer.zero_grad(set_to_none=True)
            continue

        optimizer.zero_grad(set_to_none=True)

        epoch_loss += loss.item()
        valid_batches += 1
        progress_bar.set_postfix({"loss": f"{loss.item():.6f}"})

        # Optional: print detailed debug info occasionally
        # if step % 500 == 0:
        #     print(f"[DEBUG] Step {step}: loss={loss.item():.6f}, scale={scaler.get_scale()}")

    avg_loss = epoch_loss / max(valid_batches, 1)
    print(f"Epoch {epoch + 1} completed with average loss: {avg_loss:.6f}")

    # Save checkpoint every epoch
    checkpoint_dir = f"./ldm_checkpoints/epoch_{epoch + 1}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    unet.save_pretrained(os.path.join(checkpoint_dir, "unet"))
    torch.save({
        'epoch': epoch,
        'model_state_dict': unet.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss, # for logging
    }, os.path.join(checkpoint_dir, "training_state.pth"))

print("Training resumed and completed!")



  scaler = torch.cuda.amp.GradScaler()


Epoch 7/10:   0%|          | 0/51765 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast(enabled=False):


[WARN] Non-finite loss detected (step 28507), skipping optimizer step.


  scaler = torch.cuda.amp.GradScaler()


[WARN] Non-finite loss detected (step 43397), skipping optimizer step.
[WARN] Non-finite loss detected (step 43984), skipping optimizer step.
[WARN] Non-finite loss detected (step 43985), skipping optimizer step.
[WARN] Non-finite loss detected (step 43992), skipping optimizer step.
[WARN] Non-finite loss detected (step 44054), skipping optimizer step.
[WARN] Non-finite loss detected (step 44127), skipping optimizer step.
[WARN] Non-finite loss detected (step 44130), skipping optimizer step.
[WARN] Non-finite loss detected (step 44131), skipping optimizer step.
[WARN] Non-finite loss detected (step 44132), skipping optimizer step.
[WARN] Non-finite loss detected (step 44133), skipping optimizer step.
[WARN] Non-finite loss detected (step 44134), skipping optimizer step.
[WARN] Non-finite loss detected (step 44135), skipping optimizer step.
[WARN] Non-finite loss detected (step 44136), skipping optimizer step.
[WARN] Non-finite loss detected (step 44137), skipping optimizer step.
[WARN]

Epoch 8/10:   0%|          | 0/51765 [00:00<?, ?it/s]

Epoch 8 completed with average loss: 0.169744


Epoch 9/10:   0%|          | 0/51765 [00:00<?, ?it/s]

[WARN] Non-finite loss detected (step 24719), skipping optimizer step.
[WARN] Non-finite loss detected (step 28714), skipping optimizer step.
[WARN] Non-finite loss detected (step 28723), skipping optimizer step.
[WARN] Non-finite loss detected (step 43379), skipping optimizer step.
[WARN] Non-finite loss detected (step 43380), skipping optimizer step.
[WARN] Non-finite loss detected (step 43381), skipping optimizer step.
[WARN] Non-finite loss detected (step 43383), skipping optimizer step.
[WARN] Non-finite loss detected (step 43384), skipping optimizer step.
[WARN] Non-finite loss detected (step 43385), skipping optimizer step.
[WARN] Non-finite loss detected (step 43386), skipping optimizer step.
[WARN] Non-finite loss detected (step 43387), skipping optimizer step.
[WARN] Non-finite loss detected (step 43388), skipping optimizer step.
[WARN] Non-finite loss detected (step 43389), skipping optimizer step.
[WARN] Non-finite loss detected (step 43390), skipping optimizer step.
[WARN]

Epoch 10/10:   0%|          | 0/51765 [00:00<?, ?it/s]

[WARN] Non-finite loss detected (step 0), skipping optimizer step.
[WARN] Non-finite loss detected (step 1), skipping optimizer step.
[WARN] Non-finite loss detected (step 2), skipping optimizer step.
[WARN] Non-finite loss detected (step 3), skipping optimizer step.
[WARN] Non-finite loss detected (step 4), skipping optimizer step.
[WARN] Non-finite loss detected (step 5), skipping optimizer step.
[WARN] Non-finite loss detected (step 6), skipping optimizer step.
[WARN] Non-finite loss detected (step 7), skipping optimizer step.
[WARN] Non-finite loss detected (step 8), skipping optimizer step.
[WARN] Non-finite loss detected (step 9), skipping optimizer step.
[WARN] Non-finite loss detected (step 10), skipping optimizer step.
[WARN] Non-finite loss detected (step 11), skipping optimizer step.
[WARN] Non-finite loss detected (step 12), skipping optimizer step.
[WARN] Non-finite loss detected (step 13), skipping optimizer step.
[WARN] Non-finite loss detected (step 14), skipping optimi