# Style Transfer Training

This notebook trains the style transfer model: it downloads an image dataset, prepares the images and captions, and sets up the training configuration.

## 1. Setup

### 1.1. Mount Google Drive

In [None]:
# Mount Google Drive at /content/drive.
# NOTE(review): none of the cells visible here read from or write to /content/drive —
# confirm a later cell needs this mount (e.g. to persist the trained model).
from google.colab import drive
drive.mount('/content/drive')

### 1.2. Download Dataset

Make sure your dataset is a single `.zip` archive on Google Drive, shared as "Anyone with the link", and put its file ID into `file_id` in the cell below.

In [None]:
!pip install gdown

import gdown

# Replace with the actual file ID of your dataset from Google Drive
file_id = '1acx_7IdpUg3OBXSRfnNyyjiZlIiQCE8M' 
output = 'dataset.zip'

gdown.download(id=file_id, output=output, quiet=False)

# Unzip the dataset if it's a zip file
!unzip -q dataset.zip -d /content/dataset

## 2. Prepare for Training

### 2.1. Install Dependencies

In [None]:
!pip install -q diffusers transformers accelerate peft

### 2.2. Define Training Parameters

In [None]:
# Base model from Hugging Face
# NOTE(review): confirm this repo id still resolves on the Hugging Face Hub before training.
MODEL_NAME = "runwayml/stable-diffusion-v1-5"

# Directory with your training images.
# This is the same path the downloaded archive is unzipped into.
INSTANCE_DIR = "/content/dataset"

# Directory to save the trained model
# NOTE(review): this path is not under the mounted Drive (/content/drive); presumably it is
# Colab-local storage that does not survive a runtime reset — confirm, and relocate if the
# model must persist.
OUTPUT_DIR = "/content/models/style-transfer-model"

# The unique prompt that will trigger your style
# "sks" is a placeholder for a unique token, which is a common practice.
# The preprocessing cell writes this exact text as the caption of every training image.
INSTANCE_PROMPT = "a drawing in sks style"

### 2.3. Preprocess Images and Create Captions

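This cell prepares your dataset for training. It does two things:
1. Resizes every image to 512x512 pixels, the resolution Stable Diffusion 1.5 is normally trained at. Images are converted to RGB and **overwritten in place**, and non-square images are stretched to fit (aspect ratio is not preserved), so keep a copy of the originals if you need them.
2. Creates a text file next to each image (same name, `.txt` extension) containing your unique prompt (`INSTANCE_PROMPT`).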

In [None]:
import os
from PIL import Image

# Create the output directory if it doesn't exist
os.makedirs(OUTPUT_DIR, exist_ok=True)

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
TARGET_SIZE = (512, 512)

# Match extensions case-insensitively so files such as "IMG_001.JPG" are not
# silently skipped, and sort so the processing order is the same on every run.
image_files = sorted(
    f for f in os.listdir(INSTANCE_DIR) if f.lower().endswith(IMAGE_EXTENSIONS)
)

if not image_files:
    # os.listdir is not recursive: an archive that unpacks into a subfolder
    # leaves nothing to process at the top level.
    print(
        f"No .jpg/.jpeg/.png files found directly inside {INSTANCE_DIR}; "
        "if the archive unpacked into a subfolder, point INSTANCE_DIR at it."
    )

processed_count = 0
for filename in image_files:
    image_path = os.path.join(INSTANCE_DIR, filename)

    # Resize image (overwrites the original file in place).
    try:
        with Image.open(image_path) as img:
            # Ensure image is in RGB mode; convert/resize produce a new in-memory
            # image, so the source file can be closed before it is overwritten.
            resized = img.convert("RGB").resize(TARGET_SIZE, Image.LANCZOS)
        resized.save(image_path)
    except Exception as e:
        # Best effort: report the bad file and keep going with the rest.
        print(f"Could not process {filename}: {e}")
        continue

    # Create the caption file only for images that were processed successfully,
    # so no caption is left pointing at an unusable image.
    caption_filename = os.path.splitext(filename)[0] + ".txt"
    with open(os.path.join(INSTANCE_DIR, caption_filename), "w", encoding="utf-8") as f:
        f.write(INSTANCE_PROMPT)

    processed_count += 1

# Report the number of images that actually succeeded, not the number attempted.
print(f"Processed {processed_count} of {len(image_files)} images and created caption files.")