# Run this notebook to resume fine-tuning the model from a saved checkpoint (set `checkpoint_to_resume` and `EPOCHS` below)

In [8]:
import torch
from torch.utils.data import DataLoader,Dataset
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
from transformers import CLIPTokenizer, CLIPTextModel
from tqdm.auto import tqdm
import os
import json
from PIL import Image
from torchvision.transforms import transforms


In [9]:

# Parameters
batch_size = 8
EPOCHS = 2  # SET IT ACCORDINGLY: number of additional epochs to train after the checkpoint
learning_rate = 1e-4
image_size = 64  # images are resized to image_size x image_size by `transform`
latent_size = image_size // 8  # presumably the VAE's spatial downsampling factor; not used in the code shown here
# Root of the COCO 2014 download (expects `annotations/` and `train2014/` inside).
# Can be overridden without editing the notebook via the COCO_DATA_PATH
# environment variable; the default is the original hard-coded location.
data_path = os.environ.get("COCO_DATA_PATH", "/teamspace/studios/this_studio/coco")
checkpoint_to_resume = "./ldm_checkpoints/epoch_2"  # Specify your checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer and Transforms
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    # (x - 0.5) / 0.5: rescales ToTensor's [0, 1] output to [-1, 1]
    transforms.Normalize([0.5], [0.5])
])
# Caption token length; read as a module-level name by the dataset's __getitem__.
max_length = 77



In [10]:
# Dataset class (same as your training script)
class CocoWithAnnotations(Dataset):
    """COCO 2014 caption dataset yielding ``(image, caption_token_ids)`` pairs.

    One item corresponds to one entry of the ``annotations`` list found in
    ``{path}/annotations/captions_{split}2014.json``, where ``split`` is
    ``"train"`` or ``"val"`` depending on ``train``.

    Args:
        path: Root directory of the COCO download.
        tokenizer: Callable applied to the caption string; its result must
            expose ``input_ids``.
        transform: Callable applied to the RGB ``PIL`` image.
        train: Selects the ``train`` split when true, otherwise ``val``.

    Note:
        ``__getitem__`` reads the module-level ``max_length`` for tokenization.
    """

    def __init__(self, path, tokenizer, transform, train=True):
        super().__init__()
        self.path = path
        self.transform = transform
        self.tokenizer = tokenizer
        self.train = train
        # Populated by open_json(); kept as an attribute for compatibility.
        self.data = None
        self.open_json()

    def _split(self):
        """Return the split name used in COCO file and directory names."""
        return "train" if self.train else "val"

    def open_json(self):
        """Load the caption annotations for the selected split into ``self.data``."""
        annotations_file = f'{self.path}/annotations/captions_{self._split()}2014.json'
        with open(annotations_file, 'r') as stream:
            self.data = json.load(stream)['annotations']

    def __getitem__(self, index):
        """Return the transformed image and the caption token ids for ``index``.

        Raises whatever ``Image.open`` raises if the image file is missing or
        unreadable; no sample is skipped silently here.
        """
        annot = self.data[index]
        split = self._split()
        img_id = str(annot['image_id']).zfill(12)
        image_path = f'{self.path}/{split}2014/COCO_{split}2014_{img_id}.jpg'
        # Use a context manager so the underlying file handle is released
        # deterministically; convert() returns an independent in-memory image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        image = self.transform(image)
        text_emb = self.tokenizer(
            annot['caption'],
            padding='max_length',
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )
        # Drop the leading batch dimension added by return_tensors="pt".
        return image, text_emb.input_ids.squeeze(0)

    def __len__(self):
        """Number of caption annotations in the loaded split."""
        return len(self.data)

# DataLoader
def collate_fn(batch):
    """Stack a list of ``(image, caption)`` samples into batched tensors.

    Samples in which either element is ``None`` are dropped. Returns ``None``
    when nothing is left, otherwise ``(stacked_images, stacked_captions)``.
    """
    kept_images = []
    kept_captions = []
    for img, cap in batch:
        if img is None or cap is None:
            continue
        kept_images.append(img)
        kept_captions.append(cap)

    if not kept_images:
        return None

    stacked_images = torch.stack(kept_images)
    stacked_captions = torch.stack(kept_captions)
    return stacked_images, stacked_captions

# Training split of the caption dataset.
dataset = CocoWithAnnotations(
    path=data_path,
    tokenizer=tokenizer,
    transform=transform,
    train=True,
)

# Shuffled loader; collate_fn drops None samples and may yield None for a batch.
dataloader = DataLoader(
    dataset=dataset,
    collate_fn=collate_fn,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=4,
)

In [11]:
# Models
# The UNet is restored from the checkpoint directory; the VAE and text encoder
# are loaded from their published pretrained weights.
unet = UNet2DConditionModel.from_pretrained(os.path.join(checkpoint_to_resume, "unet")).to(device)
autoencoder = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
diffusion = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear")

# Freeze VAE and text encoder: only the UNet is optimized.
autoencoder.requires_grad_(False)
text_encoder.requires_grad_(False)
# Keep the frozen models explicitly in inference mode.
autoencoder.eval()
text_encoder.eval()

# Optimizer
optimizer = torch.optim.AdamW(unet.parameters(), lr=learning_rate)
# Gradient scaling is only meaningful for CUDA mixed precision; with
# enabled=False every scaler call is a pass-through.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

# Load checkpoint
# map_location="cpu" lets a checkpoint saved on GPU be resumed on a CPU-only
# machine and avoids placing the bundled 'model_state_dict' copy on the GPU;
# optimizer.load_state_dict moves the state to the parameters' device.
checkpoint = torch.load(
    os.path.join(checkpoint_to_resume, "training_state.pth"),
    map_location="cpu",
)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
START_EPOCH = checkpoint['epoch'] + 1  # Start from the next epoch

# Training Loop
# Trains the UNet for EPOCHS additional epochs, numbered from START_EPOCH, and
# writes a checkpoint directory after every epoch.
for epoch in range(START_EPOCH, START_EPOCH + EPOCHS):
    unet.train()
    epoch_loss = 0.0
    valid_batches = 0

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{START_EPOCH + EPOCHS}")
    for batch in progress_bar:
        # collate_fn returns None when every sample in the batch was dropped.
        if batch is None:
            continue

        images, captions = batch
        # The loader pins host memory, so the copies can overlap with compute.
        images = images.to(device, non_blocking=True)
        captions = captions.to(device, non_blocking=True)

        # Mixed precision only on CUDA; on CPU this context is a no-op.
        with torch.cuda.amp.autocast(enabled=(device == "cuda")):
            with torch.no_grad():
                latents = autoencoder.encode(images).latent_dist.sample()
                # Fixed latent scaling constant applied before adding noise.
                latents = latents * 0.18215

            noise = torch.randn_like(latents)
            # One random diffusion timestep per sample in the batch.
            timesteps = torch.randint(0, diffusion.config.num_train_timesteps, (latents.shape[0],), device=device)
            noisy_latents = diffusion.add_noise(latents, noise, timesteps)

            with torch.no_grad():
                encoder_hidden_states = text_encoder(captions)[0]

            # The UNet is trained to predict the noise that was added.
            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="mean")

        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(unet.parameters(), 0.5)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        # Read the loss once: every .item() call forces a device sync.
        loss_value = loss.item()
        epoch_loss += loss_value
        valid_batches += 1
        progress_bar.set_postfix({"loss": loss_value})

    if valid_batches == 0:
        # Nothing was trained this epoch; fail loudly instead of dividing by
        # zero or saving a checkpoint that looks like progress.
        raise RuntimeError(
            f"Epoch {epoch + 1} produced no valid batches; check the dataset at {data_path}"
        )

    avg_loss = epoch_loss / valid_batches
    print(f"Epoch {epoch + 1} completed with average loss: {avg_loss:.4f}")

    # Save checkpoint
    # The directory name is 1-based (epoch + 1) while the stored 'epoch' field
    # is the 0-based loop index; the resume code adds 1 to that field.
    checkpoint_dir = f"./ldm_checkpoints/epoch_{epoch + 1}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    unet.save_pretrained(os.path.join(checkpoint_dir, "unet"))
    torch.save({
        'epoch': epoch,
        'model_state_dict': unet.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
    }, os.path.join(checkpoint_dir, "training_state.pth"))

print("Training resumed and completed!")


Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.
Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.


Epoch 3/4:   0%|          | 0/51765 [00:00<?, ?it/s]

Epoch 3 completed with average loss: 0.1888


Epoch 4/4:   0%|          | 0/51765 [00:00<?, ?it/s]

Epoch 4 completed with average loss: 0.1875
Training resumed and completed!
