In [None]:
import torch
from diffusers import StableDiffusionPipeline
from transformers import AutoTokenizer
from data_processing import load_jsonl_dataset, download_image, preprocess_image, tokenize_captions
from torch.utils.data import DataLoader, Dataset
import os


In [None]:
class ImageCaptionDataset(Dataset):
    """Dataset yielding (preprocessed image, tokenized caption) pairs.

    Each entry of ``data`` must unpack into ``(image_url, caption)``. Images
    are fetched lazily: the first access to index ``idx`` downloads the image
    to ``<images_dir>/<idx>.jpg``; later accesses reuse the file on disk.

    Args:
        data: Indexable, sized collection of ``(image_url, caption)`` pairs.
        tokenizer: Callable applied to each caption. It is invoked with
            ``padding="max_length"``, ``truncation=True`` and
            ``return_tensors="pt"`` (Hugging Face tokenizer keyword arguments).
        images_dir: Directory used as the on-disk image cache. It is created
            on demand before the first download.
    """
    def __init__(self, data, tokenizer, images_dir):
        self.data = data
        self.tokenizer = tokenizer
        self.images_dir = images_dir

    def __len__(self):
        """Return the number of image-caption pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, tokenized_caption)`` for the pair at ``idx``."""
        image_url, caption = self.data[idx]
        # NOTE(review): the cache file is keyed by position, so reordering or
        # filtering `data` between runs would pair captions with stale images.
        image_path = os.path.join(self.images_dir, f"{idx}.jpg")
        if not os.path.exists(image_path):
            # Make sure the cache directory exists before writing into it; an
            # empty `images_dir` means "current directory" and needs no mkdir.
            if self.images_dir:
                os.makedirs(self.images_dir, exist_ok=True)
            download_image(image_url, image_path)
        image = preprocess_image(image_path)
        # Pad/truncate every caption to a common length so that the default
        # DataLoader collate function can stack items into a batch; without
        # this, captions of different lengths fail to collate.
        # NOTE(review): `return_tensors="pt"` presumably keeps a leading batch
        # dimension of 1 per item, so batches would be (B, 1, L) - confirm the
        # consumer expects that or squeezes it.
        tokenized_caption = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return image, tokenized_caption



In [None]:
def _model_device(model):
    """Return the device of the model's first parameter, or None if unknown."""
    parameters = getattr(model, "parameters", None)
    if not callable(parameters):
        return None
    first_param = next(iter(parameters()), None)
    return getattr(first_param, "device", None)


def _to_device(value, device):
    """Move ``value`` to ``device`` when a device is known and ``value`` has ``.to``."""
    if device is None or not hasattr(value, "to"):
        return value
    return value.to(device)


def fine_tune_model(data_loader, model, optimizer, num_epochs, save_path):
    """Run a basic supervised training loop and save the result.

    Args:
        data_loader: Iterable yielding ``(images, captions)`` batches, where
            ``captions`` supports ``captions['input_ids']``.
        model: Object providing ``train()``, ``save_pretrained(path)`` and a
            call ``model(pixel_values=..., input_ids=...)`` whose result has a
            ``loss`` attribute supporting ``backward()`` and ``item()``.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        num_epochs: Number of full passes over ``data_loader``.
        save_path: Destination handed to ``model.save_pretrained``.
    """
    model.train()
    # Batches come off the DataLoader on the CPU; move them to wherever the
    # model's parameters live so a GPU-resident model does not hit a device
    # mismatch. If the device cannot be determined, batches are left as-is.
    device = _model_device(model)
    for epoch in range(num_epochs):
        for i, (images, captions) in enumerate(data_loader):
            images = _to_device(images, device)
            input_ids = _to_device(captions['input_ids'], device)
            optimizer.zero_grad()
            outputs = model(pixel_values=images, input_ids=input_ids)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            # Log every 10th step to keep notebook output readable.
            if i % 10 == 0:
                print(f"Epoch {epoch}, Step {i}, Loss: {loss.item()}")
    model.save_pretrained(save_path)

In [None]:
if __name__ == "__main__":
    # Configuration: everything a reader might want to tune lives here.
    # NOTE(review): the Stable Diffusion v1-4 checkpoint is, as far as I know,
    # published as "CompVis/stable-diffusion-v1-4" - confirm this ID resolves.
    MODEL_ID = "stabilityai/stable-diffusion-v1-4"
    DATASET_PATH = "path/to/dataset.jsonl"
    IMAGES_DIR = "images"
    # NOTE(review): `save_pretrained` presumably treats this as a local path
    # rather than uploading to Cloud Storage - confirm, and if so save locally
    # and copy to the bucket in a separate step.
    MODEL_SAVE_PATH = "gs://your-bucket-name/fine-tuned-model"
    BATCH_SIZE = 4
    NUM_EPOCHS = 5

    # Load tokenizer and dataset. In a diffusers pipeline repository the
    # tokenizer files live in the "tokenizer" subfolder, not at the repo root,
    # so the subfolder must be named explicitly.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, subfolder="tokenizer")
    data = load_jsonl_dataset(DATASET_PATH)
    dataset = ImageCaptionDataset(data, tokenizer, IMAGES_DIR)
    data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Load pre-trained model (requires a CUDA-capable GPU).
    model = StableDiffusionPipeline.from_pretrained(MODEL_ID).to("cuda")

    # Fine-tune the model.
    # NOTE(review): StableDiffusionPipeline is a container of sub-models
    # (text encoder, VAE, UNet, scheduler), not a torch.nn.Module, so
    # `model.parameters()`, `model.train()` and a loss-returning
    # `model(pixel_values=..., input_ids=...)` are presumably unavailable.
    # Training most likely has to target a component such as `model.unet`
    # with a dedicated denoising loss - confirm before running.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    fine_tune_model(data_loader, model, optimizer, NUM_EPOCHS, MODEL_SAVE_PATH)