In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from diffusers import UNet2DModel, DDPMScheduler
from PIL import Image
import pandas as pd
import os

In [10]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


In [15]:
# Load the processed sticker table (embedding + image path per row) and
# preview its first rows.
df = pd.read_parquet(
    path="../data/processed_sticker_dataset.parquet",
)
df.head(n=5)

Unnamed: 0,combined_embedding,image_path
0,"[0.05615041, 0.06784809, -0.03342954, 0.037553...",../data/tensor_images/AlexatorStickers\cartoon...
1,"[-0.124234326, 0.07463956, -0.011985385, 0.004...",../data/tensor_images/AlexatorStickers\cartoon...
2,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",../data/tensor_images/AlexatorStickers\cartoon...
3,"[-0.06495428, -0.04292713, 0.013164402, 0.0220...",../data/tensor_images/AlexatorStickers\cartoon...
4,"[0.027918205, 0.075559475, 0.03622711, -0.0181...",../data/tensor_images/AlexatorStickers\cartoon...


In [32]:

# Preprocessing pipeline for PIL images: resize to 32x32, convert to a
# tensor, then normalise with mean 0.5 and std 0.5.
data_transform = transforms.Compose(
    [
        transforms.Resize(size=(32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

class StickerDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(image, 0)`` pairs for the paths in ``df["image_path"]``.

    Paths ending in ``.pt`` are read with ``torch.load`` and returned exactly
    as stored; ``transform`` is NOT applied to them. Any other path is opened
    with PIL, converted to RGB, and passed through ``transform`` when one is
    given.

    Args:
        df: DataFrame with an ``image_path`` column of string file paths.
        transform: Optional callable applied to PIL images only.
    """

    def __init__(self, df, transform=None):
        self.df = df                # dataframe holding the image paths
        self.transform = transform  # preprocessing for non-tensor images

    def __len__(self):
        """Return the number of rows in the dataframe (one image per row)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the image at positional index ``idx``.

        Returns:
            Tuple ``(image, 0)``; the ``0`` is a dummy label.
        """
        # Select the column before the row so pandas does not materialise a
        # whole row (including the embedding column) to read one string.
        img_path = self.df["image_path"].iloc[idx]
        # The stored paths mix "/" with Windows "\" separators. Forward
        # slashes are accepted on every platform, so normalise to them.
        img_path = img_path.replace("\\", "/")

        if img_path.endswith(".pt"):
            # NOTE(security): torch.load unpickles the file; only use it on
            # trusted data. map_location="cpu" keeps samples on the CPU even
            # if the tensors were saved from a GPU.
            image = torch.load(img_path, map_location="cpu")
        else:
            # convert() returns a new in-memory image, so the file handle can
            # be released as soon as the conversion is done.
            with Image.open(img_path) as pil_image:
                image = pil_image.convert("RGB")
            if self.transform:
                image = self.transform(image)
        return image, 0  # Dummy label

# Wrap the dataframe in the dataset and serve it in shuffled batches of 16.
dataset = StickerDataset(df=df, transform=data_transform)
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=16,
    shuffle=True,
)

In [33]:
# U-Net for 3-channel 32x32 inputs with four resolution levels; built first,
# then moved to the selected device.
model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(64, 128, 256, 512),
    norm_num_groups=8,
)
model = model.to(device)


In [34]:
# DDPM noise schedule: 1000 training timesteps, betas from 1e-4 to 0.02.
scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.02,
)
# AdamW over all U-Net parameters.
optimizer = optim.AdamW(params=model.parameters(), lr=1e-4)

In [35]:
def train_diffusion(epochs=10):
    """Train the global ``model`` to predict the noise added to images.

    For each batch from the global ``dataloader``, one random timestep is
    drawn per image, Gaussian noise is mixed in via ``scheduler.add_noise``,
    and the model is optimised (global ``optimizer``) with an MSE loss between
    its prediction and the true noise. The loss is printed every 100 steps.

    Args:
        epochs: Number of full passes over ``dataloader``.
    """
    # Read the value through ``scheduler.config``: accessing config entries
    # directly on the scheduler is deprecated in diffusers and triggered the
    # "direct config name access" warning.
    num_train_timesteps = scheduler.config.num_train_timesteps

    model.train()
    for epoch in range(epochs):
        for step, (images, _) in enumerate(dataloader):
            images = images.to(device)
            batch_size = images.shape[0]

            # torch.randint already returns int64 on the requested device.
            timesteps = torch.randint(
                0, num_train_timesteps, (batch_size,), device=device
            )
            # randn_like inherits the device and dtype of ``images``.
            noise = torch.randn_like(images)
            noisy_images = scheduler.add_noise(images, noise, timesteps)

            predicted_noise = model(noisy_images, timesteps).sample
            loss = nn.functional.mse_loss(predicted_noise, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 100 == 0:
                print(f"Epoch {epoch+1}, Step {step}, Loss: {loss.item():.4f}")

    print("Training complete.")



In [36]:
def generate_sticker(num_images=1, image_size=None, num_inference_steps=1000):
    """Sample images from the global ``model`` and save them as PNG files.

    Starts from Gaussian noise and runs the reverse diffusion loop with the
    global ``scheduler``. Results are written to ``sticker_<i>.png`` in the
    current working directory.

    Args:
        num_images: Number of images to generate.
        image_size: Side length (int) or ``(height, width)`` of the samples.
            Defaults to ``model.config.sample_size`` so sampling matches the
            resolution the model is configured for. Pass ``64`` to reproduce
            the previously hard-coded resolution.
        num_inference_steps: Number of denoising steps handed to
            ``scheduler.set_timesteps``.
    """
    if image_size is None:
        image_size = model.config.sample_size
    if isinstance(image_size, int):
        height, width = image_size, image_size
    else:
        height, width = image_size

    model.eval()
    with torch.no_grad():
        # 3 channels: the tensor is saved as an RGB image below.
        sample = torch.randn((num_images, 3, height, width), device=device)
        scheduler.set_timesteps(num_inference_steps)

        for t in scheduler.timesteps:
            # Keep the running sample and the model's noise prediction in
            # separate variables: scheduler.step needs the prediction AND the
            # sample it was computed from. Reusing one variable for both fed
            # the prediction in as the sample and discarded the real one.
            noise_pred = model(sample, t).sample
            sample = scheduler.step(noise_pred, t, sample).prev_sample

        # Map from [-1, 1] to [0, 1], then to channel-last layout for PIL.
        images = (sample.clamp(-1, 1) + 1) / 2
        images = images.cpu().permute(0, 2, 3, 1).numpy()

        for i, img in enumerate(images):
            # Round before the cast so values are not biased downwards.
            img = (img * 255).round().astype("uint8")
            Image.fromarray(img).save(f"sticker_{i}.png")

    print("Sticker generation complete.")


In [37]:
# Train for five epochs, then sample five stickers (saved as sticker_<i>.png).
train_diffusion(5)
generate_sticker(5)


  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


Epoch 1, Step 0, Loss: 1.0602
Epoch 1, Step 100, Loss: 0.0995
Epoch 1, Step 200, Loss: 0.0573
Epoch 1, Step 300, Loss: 0.1135
Epoch 1, Step 400, Loss: 0.0585
Epoch 1, Step 500, Loss: 0.0507
Epoch 1, Step 600, Loss: 0.0304
Epoch 1, Step 700, Loss: 0.0801


KeyboardInterrupt: 