In [None]:
import torch
import torchvision.transforms as transforms
import torch.nn as nn
from diffusers import StableDiffusionImg2ImgPipeline
import math
from PIL import Image
import os
from torch.utils.data import Dataset , DataLoader
import torch.nn.functional as F
from torch.cuda.amp import autocast


In [None]:
# Base checkpoint. The original "runwayml/stable-diffusion-v1-5" repository was
# removed from the Hugging Face Hub; this id is the community mirror of it.
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
# Load the img2img pipeline in float16 and move all of its modules to the GPU.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id , torch_dtype=torch.float16).to("cuda")

In [None]:
class LoraLAYER(nn.Module):
  """Wrap a linear layer with a trainable low-rank (LoRA) update.

  forward(x) = original(x) + (x @ lora_A @ lora_B) * (alpha / rank)

  Args:
    original: the layer to adapt; it must expose ``in_features`` and
      ``out_features`` (e.g. ``nn.Linear``). Its own parameters are not
      modified here, so freezing them is the caller's responsibility.
    rank: inner dimension of the low-rank factors.
    alpha: scaling numerator; the low-rank update is multiplied by
      ``alpha / rank``.
  """
  def __init__(self , original , rank = 32 , alpha = 32):
    super().__init__()
    self.original = original
    in_feat = original.in_features
    out_feat = original.out_features

    # A is random and B is zero, so at initialisation the wrapper reproduces
    # ``original`` exactly. A is scaled by 1/sqrt(in_feat) (fan-in scaling) so
    # x @ A stays O(1); with unscaled randn its entries grow like
    # sqrt(in_feat), which inflates the gradients that reach B.
    self.lora_A = nn.Parameter(torch.randn(in_feat, rank) / math.sqrt(in_feat))
    self.lora_B = nn.Parameter(torch.zeros(rank, out_feat))
    self.scaling = alpha / rank

  def forward(self , x):
    # Low-rank path: (..., in) @ (in, rank) @ (rank, out) -> (..., out)
    lora_update = (x @ self.lora_A) @ self.lora_B
    return self.original(x) + lora_update * self.scaling

In [None]:
# Freeze every base UNet weight. LoRA factors created below are new
# nn.Parameters made after this call, so they stay trainable.
pipe.unet.requires_grad_(False)
pipe.unet.to(device="cuda", dtype=torch.float16)

# Collect the attention Linear layers first and only then mutate the module
# tree, rather than calling setattr while named_modules() is still walking it.
# Linears whose parent is already a LoraLAYER (its wrapped ``original``) are
# skipped so re-running this cell does not nest LoRA inside LoRA.
lora_targets = []
for name , module in pipe.unet.named_modules():
  if "attn" not in name or not isinstance(module , nn.Linear):
    continue
  parent_name, _, layer_name = name.rpartition(".")
  parent = pipe.unet.get_submodule(parent_name)
  if isinstance(parent, LoraLAYER):
    continue
  lora_targets.append((parent, layer_name, module))

for parent, layer_name, module in lora_targets:
  new = LoraLAYER(module, rank=64, alpha=64)
  new.to(device = "cuda" , dtype=torch.float16)
  setattr(parent, layer_name, new)

# requires_grad_(False) above also froze LoRA factors left by an earlier run of
# this cell; make every LoRA factor trainable so the cell is idempotent.
for module in pipe.unet.modules():
  if isinstance(module, LoraLAYER):
    module.lora_A.requires_grad_(True)
    module.lora_B.requires_grad_(True)

In [None]:

class DatasetA(Dataset):
  """Paired image/caption dataset for fine-tuning.

  Each image ``<stem>.png`` / ``<stem>.jpg`` in ``img_Dir`` is paired with the
  caption file ``<stem>.txt`` in ``cap_Dir``. Items are ``(image, tokens)``:
  the image is resized to ``size x size``, converted to a tensor and
  normalised with mean/std 0.5 (mapping [0, 1] to [-1, 1]); the caption is
  tokenized, padded and truncated to the tokenizer's ``model_max_length``.

  Args:
    img_Dir: directory containing the images.
    cap_Dir: directory containing one ``.txt`` caption per image.
    tokenizer: callable tokenizer returning an object with ``input_ids``
      (e.g. ``pipe.tokenizer``).
    size: side length of the square resize.
  """

  # Matched case-insensitively, so "IMG_01.JPG" is included too.
  IMAGE_EXTENSIONS = ('.png', '.jpg')

  def __init__(self , img_Dir , cap_Dir , tokenizer , size = 512):
    super().__init__()
    self.img_Dir = img_Dir
    self.cap_Dir = cap_Dir
    self.tokenizer = tokenizer
    self.images = self._list_images(img_Dir, self.IMAGE_EXTENSIONS)
    self.transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5] , [0.5])
    ])

  @staticmethod
  def _list_images(img_Dir, extensions):
    """Return the sorted file names in ``img_Dir`` ending in one of ``extensions`` (case-insensitive)."""
    # os.listdir order is arbitrary; sorting keeps dataset indices reproducible.
    return sorted(
        f for f in os.listdir(img_Dir)
        if f.lower().endswith(extensions)
    )

  @staticmethod
  def _read_caption(txt_path):
    """Read a caption file as UTF-8 and strip surrounding whitespace."""
    with open(txt_path, 'r', encoding='utf-8') as f:
      return f.read().strip()

  def __len__(self):
    return len(self.images)

  def __getitem__(self , idx):
    image_name = self.images[idx]
    image_path = os.path.join(self.img_Dir, image_name)
    # Context manager releases the file handle as soon as pixels are decoded.
    with Image.open(image_path) as img:
      image = self.transform(img.convert("RGB"))

    txt_filename = os.path.splitext(image_name)[0] + ".txt"
    caption = self._read_caption(os.path.join(self.cap_Dir, txt_filename))

    # Explicit max_length: the same value padding="max_length" falls back to,
    # stated here so the sequence length is visible.
    tokens = self.tokenizer(
        caption,
        padding="max_length",
        max_length=self.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids[0]
    return image, tokens
          

In [None]:

# Each image in anime_images is paired with the same-named .txt caption in
# info, tokenized with the pipeline's own tokenizer.
trainDATA = DatasetA(
    img_Dir="/content/data/anime_images",
    cap_Dir="/content/data/info",
    tokenizer=pipe.tokenizer,
)
# Mini-batches of 4 samples, reshuffled on every pass over the data.
trainLOADER = DataLoader(dataset=trainDATA, batch_size=4, shuffle=True)

In [None]:


# Training configuration (tune the run here).
LEARNING_RATE = 1e-5
NUM_EPOCHS = 1        # number of passes over trainLOADER
LOG_EVERY = 10        # print the loss every N optimizer steps
MAX_GRAD_NORM = 1.0

# Train in full float32. Module.to() converts parameters (and any existing
# .grad tensors) in place, so no manual per-parameter cast loop is needed.
pipe.vae.to(dtype=torch.float32)
pipe.text_encoder.to(dtype=torch.float32)
pipe.unet.to(dtype=torch.float32)
pipe.unet.train()

# Build the optimizer exactly once, over the trainable (LoRA) parameters only.
param_main = [p for p in pipe.unet.parameters() if p.requires_grad]
if not param_main:
    raise ValueError("No trainable UNet parameters found; inject the LoRA layers before training.")
optimizer = torch.optim.AdamW(param_main, lr=LEARNING_RATE)

# Read schedule/latent constants from the loaded model instead of hard-coding
# the SD 1.x values (1000 training timesteps, 0.18215 latent scale).
num_train_timesteps = pipe.scheduler.config.num_train_timesteps
latent_scale = pipe.vae.config.scaling_factor
prediction_type = getattr(pipe.scheduler.config, "prediction_type", "epsilon")
if prediction_type != "epsilon":
    # The loss below regresses the added noise, which is the correct target
    # only for epsilon-prediction models.
    raise NotImplementedError(f"Unsupported prediction_type: {prediction_type!r}")

steps_per_epoch = len(trainLOADER)
for epoch in range(NUM_EPOCHS):
    # `step` keeps counting across epochs so log lines stay unique.
    for step, (pixel, index) in enumerate(trainLOADER, start=epoch * steps_per_epoch):
        # pixel: normalised image batch; index: caption token ids (from DatasetA).
        pixel = pixel.to("cuda", dtype=torch.float32)
        index = index.to("cuda")

        # The VAE and text encoder are not optimized here; skip autograd.
        with torch.no_grad():
            latents = pipe.vae.encode(pixel).latent_dist.sample() * latent_scale
            encoder_hidden_states = pipe.text_encoder(index)[0]

        # Forward diffusion: noise each sample's latents at a random timestep.
        noise = torch.randn_like(latents)
        timestep = torch.randint(0, num_train_timesteps, (latents.shape[0],), device="cuda").long()
        noisy_latents = pipe.scheduler.add_noise(latents, noise, timestep)

        # UNet predicts the noise that was added.
        model_pred = pipe.unet(
            noisy_latents,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
        ).sample

        loss = F.mse_loss(model_pred, noise, reduction="mean")

        loss.backward()
        torch.nn.utils.clip_grad_norm_(param_main, MAX_GRAD_NORM)
        optimizer.step()
        optimizer.zero_grad()

        if step % LOG_EVERY == 0:
            print(f"Step {step} | Loss: {loss.item():.4f}")