# Importing the necessary Modules

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from dotenv import load_dotenv
from diffusers import StableDiffusionPipeline, DDPMScheduler, UNet2DConditionModel
from diffusers.training_utils import EMAModel
from transformers import CLIPTextModel, CLIPTokenizer
import os
from PIL import Image
import numpy as np
from tqdm.notebook import tqdm

# Load variables from a local `.env` file into the process environment.
# NOTE(review): no cell visible in this notebook reads these variables
# directly — presumably they carry credentials for the cloud-storage or
# model-hub clients; confirm.
load_dotenv()

# Ignore warnings
# NOTE(review): this filter silences every warning category for the whole
# kernel session, including deprecation notices from the ML libraries;
# consider narrowing it to specific categories or modules.
import warnings
warnings.filterwarnings("ignore")

## Loading the Dataset and Pre-Trained Model (Stable Diffusion 2.1)

In [2]:
# Load the pixel art dataset from cloud storage bucket.
# The download only runs when `dataset.zip` is not already on disk, so a
# full re-run neither re-fetches the archive nor needs the
# `google-cloud-storage` package once the data is local.
if not os.path.exists("dataset.zip"):
    # Imported lazily: only required when the archive has to be fetched.
    from google.cloud import storage

    client = storage.Client()
    bucket = client.get_bucket("pixel-art-dataset")
    blob = bucket.blob("dataset.zip")

    # Download under a temporary name and rename on success, so an
    # interrupted transfer is never mistaken for a complete archive.
    blob.download_to_filename("dataset.zip.part")
    os.replace("dataset.zip.part", "dataset.zip")


ModuleNotFoundError: No module named 'google.cloud'

In [5]:
# Unpack every member of dataset.zip into the current working directory
import zipfile

zip_ref = zipfile.ZipFile("dataset.zip", mode="r")
try:
    zip_ref.extractall()
finally:
    # Always release the archive handle, even if extraction fails.
    zip_ref.close()


In [3]:
# Define the Pytorch dataset
class PixelArtDataset(Dataset):
    """Image/caption pairs described by a CSV file.

    Rows are read positionally: column 0 holds the caption text and
    column 1 holds the path of the image file.

    Args:
        csv_file: Path of the CSV file; parsed with ``pd.read_csv``.
        transform: Optional callable applied to the RGB image before it is
            returned from ``__getitem__``.
    """

    def __init__(self, csv_file, transform=None):
        self.data_frame = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Return ``(image, text)`` for row ``idx``.

        The image is opened from the path in column 1, converted to RGB and,
        when a transform was supplied, passed through it.
        """
        # Get image path and text from CSV
        img_path = self.data_frame.iloc[idx, 1]
        text = self.data_frame.iloc[idx, 0]

        # `convert` returns a new image object, so the file opened by
        # `Image.open` can (and should) be closed as soon as the conversion
        # is done; otherwise one file handle leaks per sample.
        with Image.open(img_path) as source_image:
            image = source_image.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, text

In [4]:
# Define the transform: resize to 128x128, convert to a tensor, then
# normalize each of the three channels with mean 0.5 and std 0.5.
# NOTE(review): Resize is used without an explicit interpolation mode; for
# pixel art a nearest-neighbour resize may be preferable — confirm.
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

In [5]:
# Instantiate the dataset from the labels CSV; `transform` is applied to
# every image it returns.
dataset = PixelArtDataset("dataset/labels.csv", transform=transform)

In [7]:
# Load the pretrained Stable Diffusion pipeline once and reuse its components.
# Loading the tokenizer / text encoder / UNet again through separate
# `from_pretrained` calls would (a) download and hold a second copy of the
# weights, (b) leave those copies behind when `pipeline.to(device)` moves the
# pipeline, and (c) make `pipeline.save_pretrained(...)` write the pipeline's
# own untouched UNet instead of the one handed to the optimizer.
model_id = "stabilityai/stable-diffusion-2-1"
pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
tokenizer = pipeline.tokenizer
text_encoder = pipeline.text_encoder
unet = pipeline.unet

Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/1.36G [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.46G [00:00<?, ?B/s]

Error while downloading from https://cdn-lfs.hf.co/repos/b4/71/b47143176d3790e957485b59cc13cf072a4b2cbe3340d1b8fa86f53d7197236f/cce6febb0b6d876ee5eb24af35e27e764eb4f9b1d0b7c026c8c3333d4cfc916c?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.safetensors%3B+filename%3D%22model.safetensors%22%3B&Expires=1733318342&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczMzMxODM0Mn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9iNC83MS9iNDcxNDMxNzZkMzc5MGU5NTc0ODViNTljYzEzY2YwNzJhNGIyY2JlMzM0MGQxYjhmYTg2ZjUzZDcxOTcyMzZmL2NjZTZmZWJiMGI2ZDg3NmVlNWViMjRhZjM1ZTI3ZTc2NGViNGY5YjFkMGI3YzAyNmM4YzMzMzNkNGNmYzkxNmM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=V3uL1r%7EtnnOTapm5-DD7nYmc7oiY9E1iH-2xa1ohdRNPCQxI30MMmTmOJLWjNoxQ5jtHyfmgwh0HJVeUrNpK3uGBkremJ0fgU-iRjt%7EQ84ipl%7EF-jKQxxiGgq4vtWd-zopUtQxvBaUqabDKlZgQrMtI1373nAoLo-CP5OMw-46guZVl8gzzY4CSmM-4wwIr4UIfxaMGk6r7MAfze4PHCNFcNcFpGfW7FWycFXUcXDf7Ma%7EUsmTukwoBNXE0qMXO0PtJ

diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.46G [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

Error while downloading from https://cdn-lfs.hf.co/repos/b4/71/b47143176d3790e957485b59cc13cf072a4b2cbe3340d1b8fa86f53d7197236f/1238522277c48923ff2751e238f2742c562e45643f3d50cc93d163cb30638b0c?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27diffusion_pytorch_model.safetensors%3B+filename%3D%22diffusion_pytorch_model.safetensors%22%3B&Expires=1733318342&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczMzMxODM0Mn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9iNC83MS9iNDcxNDMxNzZkMzc5MGU5NTc0ODViNTljYzEzY2YwNzJhNGIyY2JlMzM0MGQxYjhmYTg2ZjUzZDcxOTcyMzZmLzEyMzg1MjIyNzdjNDg5MjNmZjI3NTFlMjM4ZjI3NDJjNTYyZTQ1NjQzZjNkNTBjYzkzZDE2M2NiMzA2MzhiMGM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=gO2nF3gXVmW9YKcpRuGq5X-D8fH5Qp%7EvlXyVaN07zMtZ3r9gmSRX3UgJYsJkf7eKXcs3ojPD2Me2Vt1t2A2ghEkeimxm9ZlA06I%7EzFFWYtVnGSLxXZkCHBCUf%7EFGqHfLxsTA8WpzeGJ1qmyHqSl6mr0HbTvonLyNy3Sfz3xNSMegOe1kJhMHeXiR5YdPvh6Xt9JE2EjcVdHPvtWhKALZhimsfEhUs-vjb

diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.46G [00:00<?, ?B/s]

Error while downloading from https://cdn-lfs.hf.co/repos/b4/71/b47143176d3790e957485b59cc13cf072a4b2cbe3340d1b8fa86f53d7197236f/1238522277c48923ff2751e238f2742c562e45643f3d50cc93d163cb30638b0c?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27diffusion_pytorch_model.safetensors%3B+filename%3D%22diffusion_pytorch_model.safetensors%22%3B&Expires=1733318342&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczMzMxODM0Mn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9iNC83MS9iNDcxNDMxNzZkMzc5MGU5NTc0ODViNTljYzEzY2YwNzJhNGIyY2JlMzM0MGQxYjhmYTg2ZjUzZDcxOTcyMzZmLzEyMzg1MjIyNzdjNDg5MjNmZjI3NTFlMjM4ZjI3NDJjNTYyZTQ1NjQzZjNkNTBjYzkzZDE2M2NiMzA2MzhiMGM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=gO2nF3gXVmW9YKcpRuGq5X-D8fH5Qp%7EvlXyVaN07zMtZ3r9gmSRX3UgJYsJkf7eKXcs3ojPD2Me2Vt1t2A2ghEkeimxm9ZlA06I%7EzFFWYtVnGSLxXZkCHBCUf%7EFGqHfLxsTA8WpzeGJ1qmyHqSl6mr0HbTvonLyNy3Sfz3xNSMegOe1kJhMHeXiR5YdPvh6Xt9JE2EjcVdHPvtWhKALZhimsfEhUs-vjb

diffusion_pytorch_model.safetensors:  25%|##4       | 849M/3.46G [00:00<?, ?B/s]

Error while downloading from https://cdn-lfs.hf.co/repos/b4/71/b47143176d3790e957485b59cc13cf072a4b2cbe3340d1b8fa86f53d7197236f/1238522277c48923ff2751e238f2742c562e45643f3d50cc93d163cb30638b0c?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27diffusion_pytorch_model.safetensors%3B+filename%3D%22diffusion_pytorch_model.safetensors%22%3B&Expires=1733318342&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczMzMxODM0Mn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9iNC83MS9iNDcxNDMxNzZkMzc5MGU5NTc0ODViNTljYzEzY2YwNzJhNGIyY2JlMzM0MGQxYjhmYTg2ZjUzZDcxOTcyMzZmLzEyMzg1MjIyNzdjNDg5MjNmZjI3NTFlMjM4ZjI3NDJjNTYyZTQ1NjQzZjNkNTBjYzkzZDE2M2NiMzA2MzhiMGM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=gO2nF3gXVmW9YKcpRuGq5X-D8fH5Qp%7EvlXyVaN07zMtZ3r9gmSRX3UgJYsJkf7eKXcs3ojPD2Me2Vt1t2A2ghEkeimxm9ZlA06I%7EzFFWYtVnGSLxXZkCHBCUf%7EFGqHfLxsTA8WpzeGJ1qmyHqSl6mr0HbTvonLyNy3Sfz3xNSMegOe1kJhMHeXiR5YdPvh6Xt9JE2EjcVdHPvtWhKALZhimsfEhUs-vjb

diffusion_pytorch_model.safetensors:  58%|#####7    | 1.99G/3.46G [00:00<?, ?B/s]

Error while downloading from https://cdn-lfs.hf.co/repos/b4/71/b47143176d3790e957485b59cc13cf072a4b2cbe3340d1b8fa86f53d7197236f/1238522277c48923ff2751e238f2742c562e45643f3d50cc93d163cb30638b0c?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27diffusion_pytorch_model.safetensors%3B+filename%3D%22diffusion_pytorch_model.safetensors%22%3B&Expires=1733318342&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczMzMxODM0Mn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9iNC83MS9iNDcxNDMxNzZkMzc5MGU5NTc0ODViNTljYzEzY2YwNzJhNGIyY2JlMzM0MGQxYjhmYTg2ZjUzZDcxOTcyMzZmLzEyMzg1MjIyNzdjNDg5MjNmZjI3NTFlMjM4ZjI3NDJjNTYyZTQ1NjQzZjNkNTBjYzkzZDE2M2NiMzA2MzhiMGM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=gO2nF3gXVmW9YKcpRuGq5X-D8fH5Qp%7EvlXyVaN07zMtZ3r9gmSRX3UgJYsJkf7eKXcs3ojPD2Me2Vt1t2A2ghEkeimxm9ZlA06I%7EzFFWYtVnGSLxXZkCHBCUf%7EFGqHfLxsTA8WpzeGJ1qmyHqSl6mr0HbTvonLyNy3Sfz3xNSMegOe1kJhMHeXiR5YdPvh6Xt9JE2EjcVdHPvtWhKALZhimsfEhUs-vjb

ReadTimeout: (ReadTimeoutError("HTTPSConnectionPool(host='cdn-lfs.hf.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: f648a89c-5dda-42a8-9b91-8c722ba4f113)')

In [None]:
# Pick the GPU when CUDA is available, otherwise fall back to the CPU,
# then move the pipeline onto that device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
pipeline.to(device)

StableDiffusionPipeline {
  "_class_name": "StableDiffusionPipeline",
  "_diffusers_version": "0.31.0",
  "_name_or_path": "stabilityai/stable-diffusion-2-1",
  "feature_extractor": [
    "transformers",
    "CLIPImageProcessor"
  ],
  "image_encoder": [
    null,
    null
  ],
  "requires_safety_checker": false,
  "safety_checker": [
    null,
    null
  ],
  "scheduler": [
    "diffusers",
    "DDIMScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

In [None]:
# Serve the dataset in shuffled batches of 4 samples
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

# Set up training configurations: Adam over the UNet parameters
optimizer = optim.Adam(params=unet.parameters(), lr=1e-4)

In [None]:
# Fine-tune the UNet with the latent-diffusion objective:
#   1. encode the images into VAE latents (the UNet's 4 input channels are
#      latent channels, not RGB plus an extra plane),
#   2. add scheduler noise at a random timestep per sample,
#   3. let the UNet predict the scheduler's target (noise or velocity),
#   4. minimise the MSE between prediction and target.
vae = pipeline.vae
# Training needs `add_noise` / `get_velocity`; build a DDPM scheduler from
# the pipeline's scheduler configuration so the noise schedule matches.
noise_scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)

# Make sure every module used below lives on `device`; this is a no-op for
# modules that are already there.
vae.to(device)
text_encoder.to(device)
unet.to(device)

# Only the UNet is optimised; freeze the other networks.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
vae.eval()
text_encoder.eval()
unet.train()

for epoch in range(1):
    epoch_loss = 0.0
    num_batches = 0

    for batch in tqdm(dataloader):
        images, texts = batch
        images = images.to(device)

        # Tokenize the input texts; truncation keeps over-long captions from
        # producing rows longer than max_length.
        text_inputs = tokenizer(
            list(texts),
            padding="max_length",
            max_length=77,
            truncation=True,
            return_tensors="pt",
        ).to(device)

        # Frozen networks: no autograd graph needed.
        with torch.no_grad():
            # Only the token ids are passed, mirroring the diffusers
            # text-to-image training example.
            text_embeddings = text_encoder(text_inputs.input_ids).last_hidden_state
            latents = vae.encode(images).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

        # Sample noise and one random timestep per image, then diffuse.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0,
            noise_scheduler.config.num_train_timesteps,
            (latents.shape[0],),
            device=latents.device,
        ).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # The regression target depends on how the model was parameterised.
        prediction_type = noise_scheduler.config.prediction_type
        if prediction_type == "epsilon":
            target = noise
        elif prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unsupported prediction type: {prediction_type}")

        # Forward pass through UNet
        model_pred = unet(
            noisy_latents, timesteps, encoder_hidden_states=text_embeddings
        ).sample

        # Prediction and target now share the latent shape.
        loss = F.mse_loss(model_pred.float(), target.float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Report the mean batch loss; guarded so an empty dataloader does not fail.
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss / max(num_batches, 1)}")


  0%|          | 0/250 [00:00<?, ?it/s]

In [None]:
# `unet` is the module that was handed to the optimizer. Make sure the
# pipeline being saved holds that exact instance rather than a separately
# loaded copy, otherwise the fine-tuned weights would not be written.
if pipeline.unet is not unet:
    pipeline.register_modules(unet=unet)
pipeline.save_pretrained("fine-tuned-stable-diffusion")