In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from diffusers import UNet2DModel, DDPMScheduler
from PIL import Image
import pandas as pd
import os

In [10]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


In [15]:
# Load the preprocessed sticker table: each row pairs a `combined_embedding`
# with the `image_path` of the corresponding image file.
# NOTE(review): in the preview output, row 2's embedding displays only zeros
# (possibly an all-zero placeholder for a missing embedding) and the paths use
# backslash separators (e.g. "AlexatorStickers\cartoon...") — confirm both are
# intended before training on a non-Windows host.
sticker_table_path = "../data/processed_sticker_dataset.parquet"
df = pd.read_parquet(sticker_table_path)
df.head()

Unnamed: 0,combined_embedding,image_path
0,"[0.05615041, 0.06784809, -0.03342954, 0.037553...",../data/tensor_images/AlexatorStickers\cartoon...
1,"[-0.124234326, 0.07463956, -0.011985385, 0.004...",../data/tensor_images/AlexatorStickers\cartoon...
2,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",../data/tensor_images/AlexatorStickers\cartoon...
3,"[-0.06495428, -0.04292713, 0.013164402, 0.0220...",../data/tensor_images/AlexatorStickers\cartoon...
4,"[0.027918205, 0.075559475, 0.03622711, -0.0181...",../data/tensor_images/AlexatorStickers\cartoon...


In [None]:

# Preprocessing applied to every sticker image (see StickerDataset below):
#   1. Resize to exactly 64x64 — a (h, w) tuple does not preserve aspect ratio.
#   2. ToTensor: RGB PIL image -> float tensor of shape (C, H, W) in [0, 1].
#   3. Normalize with mean 0.5 / std 0.5, i.e. x -> (x - 0.5) / 0.5, which maps
#      [0, 1] to [-1, 1]. The single-element lists broadcast over all channels.
data_transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

class StickerDataset(torch.utils.data.Dataset):
    """Map-style dataset that loads the sticker images listed in a DataFrame.

    Args:
        df: DataFrame with an ``image_path`` column. Item ``i`` of the dataset is
            the image at ``df["image_path"].iloc[i]`` (positional, not by index
            label). Other columns (e.g. ``combined_embedding``) are not read.
        transform: Optional callable applied to the RGB ``PIL.Image`` (for
            example ``data_transform``). When ``None`` the PIL image is returned.

    Each item is a tuple ``(image, 0)``; the ``0`` is a constant dummy label.
    """

    def __init__(self, df, transform=None):
        self.df = df  # sample table; only the "image_path" column is used
        self.transform = transform  # optional callable applied to each image

    def __len__(self):
        """Return the number of images, i.e. the number of DataFrame rows."""
        return len(self.df)

    @staticmethod
    def _resolve_path(img_path):
        """Return ``img_path`` as a string with backslash separators turned into ``/``.

        The stored paths mix ``/`` and backslash separators. Forward slashes are
        accepted by both Windows and POSIX, whereas backslashes only act as
        separators on Windows, so normalising lets the same table load anywhere.
        """
        return str(img_path).replace("\\", "/")

    def __getitem__(self, idx):
        """Load the image at position ``idx``, convert it to RGB and transform it.

        Returns:
            ``(image, 0)``: the (optionally transformed) image and a dummy label.
        """
        # Select the column first so a whole row Series is not built per item.
        img_path = self._resolve_path(self.df["image_path"].iloc[idx])
        # The context manager closes the file deterministically; without it Pillow
        # may keep the handle open (lazy / multi-frame formats). convert() returns
        # a new, fully loaded image, so it stays valid after the file is closed.
        with Image.open(img_path) as src:
            image = src.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, 0  # Dummy label

# Create DataLoader.
# The shuffle order comes from a seeded generator, so it is identical on every
# Restart & Run All; change DATALOADER_SEED to get a different (still
# reproducible) order.
BATCH_SIZE = 16
DATALOADER_SEED = 42

dataset = StickerDataset(df, transform=data_transform)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(DATALOADER_SEED),
)