# Importing the necessary Modules

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from dotenv import load_dotenv
from diffusers import StableDiffusionPipeline, DDPMScheduler, UNet2DConditionModel
from diffusers.training_utils import EMAModel
from transformers import CLIPTextModel, CLIPTokenizer
import os
from PIL import Image
import numpy as np
from tqdm.notebook import tqdm

load_dotenv()

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

## Loading Pre Trained Model (Stable Diffusion 3.5 large turbo)

In [4]:
# Load the pixel art dataset from cloud storage bucket
from google.cloud import storage
client = storage.Client()
bucket = client.get_bucket("pixel-art-dataset")
blob = bucket.blob("dataset.zip")
blob.download_to_filename("dataset.zip")


In [5]:
# Extract dataset.zip
import zipfile
with zipfile.ZipFile("dataset.zip", "r") as zip_ref:
    zip_ref.extractall()


In [6]:
# Define the Pytorch dataset
class PixelArtDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.data_frame = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        # Get image path and text from CSV
        img_path = self.data_frame.iloc[idx, 1]
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        text = self.data_frame.iloc[idx, 0]

        return image, text

In [7]:
# Define the transform
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [8]:
# Initialize the dataset
dataset = PixelArtDataset(csv_file='dataset/labels.csv',
                                 transform=transform)

In [9]:
# Load the pretrained Stable Diffusion model components
model_id = "stabilityai/stable-diffusion-2-1"
pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")

model_index.json:   0%|          | 0.00/537 [00:00<?, ?B/s]

Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]

tokenizer/tokenizer_config.json:   0%|          | 0.00/824 [00:00<?, ?B/s]

text_encoder/config.json:   0%|          | 0.00/633 [00:00<?, ?B/s]

scheduler/scheduler_config.json:   0%|          | 0.00/345 [00:00<?, ?B/s]

tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

(…)ature_extractor/preprocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer/special_tokens_map.json:   0%|          | 0.00/460 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.36G [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.46G [00:00<?, ?B/s]

unet/config.json:   0%|          | 0.00/939 [00:00<?, ?B/s]

vae/config.json:   0%|          | 0.00/611 [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

In [10]:
# Move model to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline.to(device)

StableDiffusionPipeline {
  "_class_name": "StableDiffusionPipeline",
  "_diffusers_version": "0.31.0",
  "_name_or_path": "stabilityai/stable-diffusion-2-1",
  "feature_extractor": [
    "transformers",
    "CLIPImageProcessor"
  ],
  "image_encoder": [
    null,
    null
  ],
  "requires_safety_checker": false,
  "safety_checker": [
    null,
    null
  ],
  "scheduler": [
    "diffusers",
    "DDIMScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

In [12]:
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# Set up training configurations
optimizer = torch.optim.Adam(unet.parameters(), lr=1e-4)

In [18]:
# Training loop with tqdm for progress display
for epoch in range(1):  # Adjust the number of epochs as needed
    running_loss = 0.0  # Track total loss for the epoch

    for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}", unit="batch"):
        images, texts = batch
        images = images.to(device)

        # Tokenize input texts with truncation and padding
        text_inputs = tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=77,
            return_tensors="pt"
        ).to(device)

        # Forward pass through the text encoder
        with torch.no_grad():  # Freeze text encoder for training efficiency
            text_embeddings = text_encoder(**text_inputs).last_hidden_state

        # Add noise to images (4-channel input)
        noise = torch.randn_like(images)  # Shape: [batch_size, 3, H, W]
        noisy_images = torch.cat([images, noise[:, :1, :, :]], dim=1)  # Create 4-channel input

        # Forward pass through the UNet
        noise_pred = unet(noisy_images, timestep=50, encoder_hidden_states=text_embeddings).sample

        # Compute the loss (compare predicted noise with true noise)
        loss = torch.nn.functional.mse_loss(noise_pred, noise[:, :1, :, :])  # Match 1-channel noise for consistency

        # Backward pass and optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update loss tracking
        running_loss += loss.item()

    # Print average loss for the epoch
    avg_loss = running_loss / len(dataloader)
    print(f"Epoch {epoch + 1} Completed. Average Loss: {avg_loss:.4f}")


Epoch 1:   0%|          | 0/250 [00:00<?, ?batch/s]

KeyboardInterrupt: 

In [19]:
pipeline.save_pretrained("fine-tuned-stable-diffusion")

In [20]:
# Generate an image based on a text prompt
prompt = "A futuristic cityscape at sunset"
image = pipeline(prompt).images[0]

# Save the generated image to a file
image.save("generated_image.png")

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

In [21]:
# Display the generated image
image.show()