# 📚 Kanji Dataset Processing Components

**Missing components for the Kanji generation task**  
This notebook provides the complete dataset-processing pipeline for training on real Kanji: English definitions from KANJIDIC2 paired with stroke images rendered from KanjiVG.

In [None]:
# Install required dependencies
!pip install requests svgpathtools cairosvg pillow tqdm

In [None]:
# Kanji Dataset Processing Components
# This code provides the missing components for real Kanji generation

import xml.etree.ElementTree as ET
import requests
import svgpathtools
from cairosvg import svg2png
from PIL import Image
import io
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import os
import json
from tqdm import tqdm

class KanjiDataProcessor:
    """
    Process KANJIDIC2 and KanjiVG data for Stable Diffusion training.

    Downloads the KANJIDIC2 dictionary (for English meanings) and KanjiVG
    stroke SVGs (for glyph images), renders each SVG to a black-on-white RGB
    image, and pairs it with a short text definition.

    Attributes:
        data_dir: Directory where downloaded and derived files are cached.
        request_timeout: Seconds to wait for each HTTP request.
    """

    # Per-request HTTP timeout in seconds; without one, a stalled connection
    # would block the download loop indefinitely.
    request_timeout = 60

    def __init__(self, data_dir="kanji_data"):
        self.data_dir = data_dir
        os.makedirs(data_dir, exist_ok=True)

    def download_kanjidic2(self):
        """Download and decompress the KANJIDIC2 XML file.

        Returns:
            Path of the decompressed ``kanjidic2.xml`` inside ``data_dir``.

        Raises:
            requests.HTTPError: if the server answers with an error status.
            OSError: (incl. ``gzip.BadGzipFile``) if the payload is not valid
                gzip data or the file cannot be written.
        """
        import gzip

        url = "http://www.edrdg.org/kanjidic/kanjidic2.xml.gz"
        print("📥 Downloading KANJIDIC2...")

        response = requests.get(url, timeout=self.request_timeout)
        # Fail loudly on 4xx/5xx instead of trying to gunzip an error page.
        response.raise_for_status()

        # Decompress fully in memory *before* touching the destination, then
        # write to a temp file and atomically rename it. create_dataset skips
        # the download whenever kanjidic2.xml exists, so a truncated file must
        # never be left at the final path.
        xml_bytes = gzip.decompress(response.content)
        xml_path = os.path.join(self.data_dir, "kanjidic2.xml")
        tmp_path = xml_path + ".part"
        with open(tmp_path, 'wb') as f:
            f.write(xml_bytes)
        os.replace(tmp_path, xml_path)

        print(f"✅ KANJIDIC2 saved to {xml_path}")
        return xml_path

    def parse_kanjidic2(self, xml_path):
        """Parse KANJIDIC2 XML and extract each kanji's English meanings.

        Args:
            xml_path: Path to an uncompressed KANJIDIC2 XML file.

        Returns:
            Dict mapping each kanji character (in document order) to a dict
            with keys ``'character'``, ``'definition'`` (the first three
            meanings joined by ", ") and ``'all_meanings'`` (every English
            meaning). Characters with no English meaning are omitted.
        """
        print("📖 Parsing KANJIDIC2...")

        root = ET.parse(xml_path).getroot()
        kanji_data = {}

        for character in root.findall('.//character'):
            literal = character.find('.//literal')
            if literal is None or not literal.text:
                continue
            kanji = literal.text

            meanings = []
            reading_meaning = character.find('.//reading_meaning')
            # Compare with None explicitly: an Element's truth value reflects
            # its child count (and is deprecated), not whether it was found.
            if reading_meaning is not None:
                for meaning in reading_meaning.findall('.//meaning'):
                    # Meanings without an m_lang attribute are English. Skip
                    # empty <meaning/> elements: their .text is None and would
                    # make ", ".join() raise TypeError.
                    if 'm_lang' not in meaning.attrib and meaning.text:
                        meanings.append(meaning.text)

            if meanings:
                kanji_data[kanji] = {
                    'character': kanji,
                    'definition': ", ".join(meanings[:3]),  # first 3 meanings
                    'all_meanings': meanings,
                }

        print(f"✅ Parsed {len(kanji_data)} kanji characters")
        return kanji_data

    def download_kanjivg(self, kanji_list):
        """Download KanjiVG SVG files for the given kanji.

        SVGs already present in ``<data_dir>/svg`` are reused, not re-fetched.
        Kanji whose request fails or returns a non-200 status are skipped.

        Args:
            kanji_list: Iterable of single-character strings.

        Returns:
            List of ``(kanji, svg_path)`` tuples for every kanji whose SVG is
            now on disk.
        """
        base_url = "https://raw.githubusercontent.com/KanjiVG/kanjivg/master/kanji/"
        svg_dir = os.path.join(self.data_dir, "svg")
        os.makedirs(svg_dir, exist_ok=True)

        print("📥 Downloading KanjiVG SVG files...")

        downloaded = []
        # One session reuses the connection pool across all requests.
        with requests.Session() as session:
            for kanji in tqdm(kanji_list):
                # File name is the code point as 5-digit lowercase hex (e.g. 04e00.svg).
                svg_filename = f"{ord(kanji):05x}.svg"
                svg_path = os.path.join(svg_dir, svg_filename)

                if os.path.exists(svg_path):
                    downloaded.append((kanji, svg_path))
                    continue

                try:
                    response = session.get(base_url + svg_filename,
                                           timeout=self.request_timeout)
                    if response.status_code != 200:
                        continue
                    # Write raw bytes (no decode / newline translation) via a
                    # temp file so an interrupted write never leaves a partial
                    # SVG that later runs would treat as cached.
                    tmp_path = svg_path + ".part"
                    with open(tmp_path, 'wb') as f:
                        f.write(response.content)
                    os.replace(tmp_path, svg_path)
                except (requests.RequestException, OSError) as e:
                    print(f"Failed to download {kanji}: {e}")
                    continue
                downloaded.append((kanji, svg_path))

        print(f"✅ Downloaded {len(downloaded)} SVG files")
        return downloaded

    def process_svg_to_image(self, svg_path, size=128):
        """Render an SVG as a pure black-on-white RGB image.

        Stroke-order numbers (``<text>`` elements) are removed before
        rendering, and the result is thresholded so every pixel is either
        0 (black) or 255 (white).

        Args:
            svg_path: Path to the SVG file.
            size: Output width and height in pixels.

        Returns:
            A ``size`` x ``size`` PIL image in RGB mode.
        """
        import re

        with open(svg_path, 'r', encoding='utf-8') as f:
            svg_content = f.read()

        # Remove stroke-order number labels (<text> elements).
        svg_content = re.sub(r'<text[^>]*>.*?</text>', '', svg_content, flags=re.DOTALL)

        # Force coloured strokes to pure black.
        svg_content = svg_content.replace('stroke="red"', 'stroke="#000000"')
        svg_content = svg_content.replace('stroke="orange"', 'stroke="#000000"')

        # If no stroke attribute appears anywhere, give every path an explicit
        # black stroke so it renders visibly.
        if 'stroke=' not in svg_content:
            svg_content = svg_content.replace('<path', '<path stroke="#000000" fill="none" stroke-width="3"')

        png_data = svg2png(
            bytestring=svg_content.encode('utf-8'),
            output_width=size,
            output_height=size,
            background_color='white',
        )

        # Grayscale, then binarize so anti-aliased edges become pure 0/255.
        gray = np.array(Image.open(io.BytesIO(png_data)).convert('L'))
        threshold = 128
        binary = np.where(gray < threshold, 0, 255).astype(np.uint8)

        # Convert to 3-channel RGB for the Stable Diffusion pipeline.
        return Image.fromarray(binary).convert('RGB')

    def create_dataset(self, num_samples=None):
        """Build the complete kanji image/definition dataset.

        Downloads KANJIDIC2 if it is not cached, parses it, downloads the
        matching KanjiVG SVGs, renders each one to a 128x128 image, and
        writes ``dataset_metadata.json`` (everything except the images) to
        ``data_dir``.

        Args:
            num_samples: If not None, only the first ``num_samples`` kanji
                (in KANJIDIC2 order) are used; None means all of them.

        Returns:
            List of dicts with keys ``'kanji'``, ``'definition'``,
            ``'image'`` (PIL RGB image) and ``'meanings'``. Kanji whose SVG
            could not be downloaded or rendered are skipped.
        """
        print("🏗️ Creating Kanji dataset...")

        xml_path = os.path.join(self.data_dir, "kanjidic2.xml")
        if not os.path.exists(xml_path):
            xml_path = self.download_kanjidic2()

        kanji_data = self.parse_kanjidic2(xml_path)

        kanji_list = list(kanji_data)
        # `is not None` so num_samples=0 yields an empty selection instead of
        # silently falling back to the full dataset.
        if num_samples is not None:
            kanji_list = kanji_list[:num_samples]

        svg_files = self.download_kanjivg(kanji_list)

        dataset = []
        print("🎨 Converting SVG to images...")

        for kanji, svg_path in tqdm(svg_files):
            try:
                image = self.process_svg_to_image(svg_path, size=128)
            except Exception as e:
                # Best-effort: one malformed SVG must not abort the whole build.
                print(f"Failed to process {kanji}: {e}")
                continue
            entry = kanji_data[kanji]
            dataset.append({
                'kanji': kanji,
                'definition': entry['definition'],
                'image': image,
                'meanings': entry['all_meanings'],
            })

        print(f"✅ Created dataset with {len(dataset)} entries")

        # Persist everything except the PIL images.
        metadata_path = os.path.join(self.data_dir, "dataset_metadata.json")
        metadata = [
            {
                'kanji': item['kanji'],
                'definition': item['definition'],
                'meanings': item['meanings'],
            }
            for item in dataset
        ]
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, ensure_ascii=False, indent=2)

        print(f"💾 Metadata saved to {metadata_path}")
        return dataset


class KanjiDataset(Dataset):
    """PyTorch ``Dataset`` over kanji image / definition pairs.

    Args:
        data: Indexable sequence of dicts with at least ``'image'``,
            ``'definition'`` and ``'kanji'`` keys.
        transform: Optional callable applied to each image. When it is falsy,
            the image is converted with ``np.array``, scaled to ``[-1, 1]``
            as float32, and permuted from (H, W, C) to (C, H, W).
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        return {
            'image': self._prepare_image(record['image']),
            'text': record['definition'],
            'kanji': record['kanji'],
        }

    def _prepare_image(self, picture):
        """Apply the user transform, or the default [-1, 1] CHW conversion."""
        if not self.transform:
            unit = np.array(picture).astype(np.float32) / 255.0
            centered = (unit - 0.5) * 2  # [0, 1] -> [-1, 1]
            return torch.from_numpy(centered).permute(2, 0, 1)
        return self.transform(picture)


def integrate_with_trainer(trainer, dataset, *, num_workers=2, shuffle=True):
    """
    Integrate the real Kanji dataset with an existing trainer.

    Wraps ``dataset`` in a ``KanjiDataset`` and a ``DataLoader`` sized by
    ``trainer.batch_size``, then replaces the trainer's synthetic data by
    setting ``trainer.kanji_dataloader`` (the loader) and
    ``trainer.kanji_dataset`` (the raw ``dataset`` list).

    Args:
        trainer: Object exposing a ``batch_size`` attribute; it is mutated.
        dataset: List of entries (dicts with ``'image'``, ``'definition'``,
            ``'kanji'``).
        num_workers: DataLoader worker processes. Set to 0 if worker
            processes fail to start, e.g. when classes defined in a notebook
            cannot be pickled under the ``spawn`` start method.
        shuffle: Whether the loader reshuffles the data every epoch.

    Returns:
        The created ``DataLoader``.

    Raises:
        ValueError: if ``dataset`` is empty (nothing was downloaded/rendered).
    """
    # Fail early with a clear message; otherwise DataLoader's RandomSampler
    # raises an opaque "num_samples should be a positive integer" error.
    if not dataset:
        raise ValueError("Cannot integrate an empty Kanji dataset with the trainer")

    print("🔄 Integrating real Kanji dataset with trainer...")

    kanji_dataset = KanjiDataset(dataset)
    dataloader = DataLoader(
        kanji_dataset,
        batch_size=trainer.batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )

    # Replace the trainer's synthetic data source with the real one.
    trainer.kanji_dataloader = dataloader
    trainer.kanji_dataset = dataset

    print("✅ Dataset integrated successfully")
    return dataloader

## 🧪 Test the Dataset Processing

In [None]:
# Test the dataset processing
if __name__ == "__main__":
    # Smoke test: build a small dataset (first 100 KANJIDIC2 entries).
    processor = KanjiDataProcessor()
    dataset = processor.create_dataset(num_samples=100)

    # Show one sample entry and its rendered image.
    if dataset:
        sample = dataset[0]
        print("\n📊 Sample entry:")
        print(f"  Kanji: {sample['kanji']}")
        print(f"  Definition: {sample['definition']}")
        print(f"  Image size: {sample['image'].size}")

        import matplotlib.pyplot as plt
        plt.figure(figsize=(6, 6))
        plt.imshow(sample['image'], cmap='gray')
        plt.title(f'Kanji: {sample["kanji"]} - {sample["definition"]}', fontsize=12)
        plt.axis('off')
        plt.show()

    # Summary and usage hint. These lines were previously over-indented
    # relative to the print above, which is an IndentationError.
    print("\n✅ Dataset processing test completed!")
    print(f"   • Total entries: {len(dataset)}")
    print(f"   • Data directory: {processor.data_dir}")
    print("\n💡 To integrate with existing trainer:")
    print("   trainer = ColabOptimizedTrainer()")
    print("   dataloader = integrate_with_trainer(trainer, dataset)")
    print("   trainer.train()  # Now uses real Kanji data")