1. Dataset Handling


You have images for 6 classes: Rook, Bishop, King,
Knight, Pawn, and Queen.

Goal: Ensure each class has enough samples for effective generative model training and validation.

2. Data Preparation Steps

In [76]:
#Step 1: Resize Images
from PIL import Image
import os

def is_image_file(file_name):
    """Return True when ``file_name`` ends in a recognised image extension.

    Only the final extension of the name is inspected, case-insensitively;
    the file itself is never opened.
    """
    _, extension = os.path.splitext(file_name)
    recognised = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp')
    return extension.lower() in recognised

def resize_images(input_folders, output_folder, size=(256, 256)):
    """Resize every image of each class folder and save it under ``output_folder``.

    For each folder in ``input_folders`` a sub-folder with the same name is
    created inside ``output_folder``. Each image is resized to ``size`` with
    LANCZOS resampling and saved under its original file name. Files without
    a recognised image extension are reported and skipped, and a failure on
    one image is printed without stopping the run.

    Args:
        input_folders: Iterable of directory paths, one per class.
        output_folder: Root directory that receives the per-class folders.
        size: Target ``(width, height)`` in pixels.
    """
    os.makedirs(output_folder, exist_ok=True)

    for input_folder in input_folders:
        # normpath drops a trailing separator; without it basename() returns ''
        # and the images would be written directly into output_folder.
        class_name = os.path.basename(os.path.normpath(input_folder))
        class_output_folder = os.path.join(output_folder, class_name)
        os.makedirs(class_output_folder, exist_ok=True)

        for img_name in os.listdir(input_folder):
            if not is_image_file(img_name):
                print(f"Skipping non-image file: {img_name}")
                continue

            img_path = os.path.join(input_folder, img_name)

            try:
                # The context manager closes the source file handle promptly
                # instead of leaving it to the garbage collector.
                with Image.open(img_path) as img:
                    resized = img.resize(size, Image.Resampling.LANCZOS)  # high-quality resampling
                    resized.save(os.path.join(class_output_folder, img_name))
            except Exception as e:
                # Best effort: report the bad file and continue with the rest.
                print(f"Error processing {img_name}: {e}")

# Example usage: one source folder per chess-piece class
input_folders = [
    r"C:\Users\Keshavi\Downloads\Chess\Queen",
    r"C:\Users\Keshavi\Downloads\Chess\Rook",
    r"C:\Users\Keshavi\Downloads\Chess\Bishop",
    r"C:\Users\Keshavi\Downloads\Chess\Knight",
    r"C:\Users\Keshavi\Downloads\Chess\Pawn",
    r"C:\Users\Keshavi\Downloads\Chess\King"
]

# Raw strings keep backslashes literally, so they must not be doubled here;
# the doubled form put "\\" into the path itself.
resized_chess_pieces = r'C:\Users\Keshavi\Downloads\resized_chess_pieces'

resize_images(input_folders, resized_chess_pieces, size=(256, 256))


In [77]:
#Step 2: Normalize Image Pixel Values
import torchvision.transforms as transforms
from PIL import Image

# Preprocessing pipeline: resize to 256x256, convert to a tensor, then
# normalise each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 normalises the values to [-1, 1]
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

def load_and_normalize_image(image_path):
    """Load an image from disk and apply the module-level ``transform``.

    The image is converted to RGB first: ``transform`` normalises with
    three-channel statistics, so grayscale, palette or RGBA files would
    otherwise fail in its ``Normalize`` step.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The result of ``transform`` applied to the RGB image.
    """
    # convert() reads the pixel data and returns a new image, so the file
    # handle can be closed before the transform runs.
    with Image.open(image_path) as image:
        rgb_image = image.convert("RGB")
    return transform(rgb_image)


In [78]:
import os
from sklearn.model_selection import train_test_split
import shutil




In [83]:
import os
from sklearn.model_selection import train_test_split
import shutil

def split_dataset(input_folders, output_folder, test_size=0.2):
    """Move the images of each class folder into train/ and val/ sub-folders.

    For every folder in ``input_folders`` the images are split with
    ``train_test_split`` (fixed ``random_state=42``) and moved to
    ``<output_folder>/train/<class>`` and ``<output_folder>/val/<class>``,
    where ``<class>`` is the name of the input folder. The files are moved,
    not copied, so the input folders are emptied by this call.

    Args:
        input_folders: Iterable of directory paths, one per class.
        output_folder: Root directory that receives ``train`` and ``val``.
        test_size: Fraction (or count) of images that goes to ``val``;
            passed straight to ``train_test_split``.
    """
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp')
    os.makedirs(output_folder, exist_ok=True)

    for input_folder in input_folders:
        # normpath drops a trailing separator so basename() never returns ''.
        class_name = os.path.basename(os.path.normpath(input_folder))
        train_folder = os.path.join(output_folder, 'train', class_name)
        val_folder = os.path.join(output_folder, 'val', class_name)

        os.makedirs(train_folder, exist_ok=True)
        os.makedirs(val_folder, exist_ok=True)

        # Match extensions case-insensitively (e.g. ".JPG"), and sort because
        # os.listdir() order is arbitrary: the fixed random_state only gives a
        # reproducible split when the input order is stable too.
        images = sorted(
            os.path.join(input_folder, f)
            for f in os.listdir(input_folder)
            if f.lower().endswith(image_extensions)
        )

        if len(images) < 2:
            # train_test_split raises on an empty list and cannot split a
            # single sample; this also makes a re-run on already-moved
            # folders a no-op instead of a crash.
            if not images:
                print(f"No images found in {input_folder}; nothing to split.")
            train_images, val_images = images, []
        else:
            train_images, val_images = train_test_split(
                images, test_size=test_size, random_state=42
            )

        # Move images to the corresponding folders
        for image in train_images:
            shutil.move(image, os.path.join(train_folder, os.path.basename(image)))
        for image in val_images:
            shutil.move(image, os.path.join(val_folder, os.path.basename(image)))

# Example usage: split the resized images of every class into train/val sets
resized_chess_pieces = [
    r"C:\Users\Keshavi\Downloads\resized_chess_pieces\King",
    r"C:\Users\Keshavi\Downloads\resized_chess_pieces\Pawn",
    r"C:\Users\Keshavi\Downloads\resized_chess_pieces\Knight",
    r"C:\Users\Keshavi\Downloads\resized_chess_pieces\Bishop",
    r"C:\Users\Keshavi\Downloads\resized_chess_pieces\Rook",
    r"C:\Users\Keshavi\Downloads\resized_chess_pieces\Queen",
]

# Root folder that receives the train/ and val/ trees
chess_pieces_split = r"C:\Users\Keshavi\Downloads\chess_pieces_split"
split_dataset(
    input_folders=resized_chess_pieces,
    output_folder=chess_pieces_split,
    test_size=0.2,
)


In [102]:
# Show which diffusers release is installed in this kernel.
import diffusers
print(diffusers.__version__)


  from .autonotebook import tqdm as notebook_tqdm


0.30.3


In [5]:
#Data Augmentation Using Generative Models
#1. Stable Diffusion

# Import necessary libraries
import os
from diffusers import StableDiffusionPipeline

def generate_images_stable_diffusion(prompt, output_folder, num_images=5):
    # Load the Stable Diffusion model
    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = pipe.to(device)


    for i in range(num_images):
        image = pipe(prompt).images[0]
        # Save the generated image to the output folder
        image.save(f"{output_folder}/{prompt.replace(' ', '_')}_{i}.png")

# Destination directory for the generated images
output_folder = r"C:\Users\Keshavi\Downloads\generated_chess_images"

# Create the destination up front (no error if it is already there)
os.makedirs(output_folder, exist_ok=True)

# Chess pieces to generate images for
chess_pieces = ["Rook", "Bishop", "Knight", "Pawn", "Queen", "King"]

# Five images per piece, prompted as "A chess <piece>"
for piece in chess_pieces:
    generate_images_stable_diffusion(
        prompt=f"A chess {piece}",
        output_folder=output_folder,
        num_images=5,
    )





Couldn't connect to the Hub: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/models/CompVis/stable-diffusion-v1-4 (Caused by SSLError(SSLCertVerificationError(1, '[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: EE certificate key too weak (_ssl.c:1000)')))"), '(Request ID: 85788cbe-7d18-4f98-947f-0b7cb40d0127)').
Will try to load from local cache.
Loading pipeline components...: 100%|██████████| 7/7 [00:03<00:00,  1.84it/s]
100%|██████████| 50/50 [1:06:30<00:00, 79.80s/it] 
  6%|▌         | 3/50 [03:42<52:51, 67.47s/it]   