In [None]:
!pip install opencv-python numpy matplotlib scikit-learn tensorflow

In [None]:
import cv2
import numpy as np
from pathlib import Path
import random
import matplotlib
matplotlib.use('TkAgg')    
import matplotlib.pyplot as plt
# Look up the active IPython shell, if any. `shell` is None when IPython is
# not installed or when this code runs outside an IPython/Jupyter session.
try:
    import IPython
except ImportError:
    # Only a missing IPython is expected here; any other error should surface
    # instead of being silently swallowed.
    shell = None
else:
    shell = IPython.get_ipython()
    # To switch the notebook to the Qt backend instead of TkAgg, call:
    #   shell.enable_matplotlib(gui='qt')

class DataAugmentor:
    """Primitive image transforms used to build augmented character samples.

    Every method is static, takes an image array and returns a new array of
    the same height and width. Areas uncovered by a geometric transform are
    filled with 0 (black).
    """

    @staticmethod
    def rotate(img, angle):
        """Rotate ``img`` by ``angle`` degrees around its centre.

        Args:
            img: Image array; only the first two dimensions (height, width)
                are used for geometry, so extra channel axes are accepted.
            angle: Rotation angle in degrees, passed to
                ``cv2.getRotationMatrix2D``.

        Returns:
            The rotated image, same width and height as the input.
        """
        # shape[:2] instead of a 2-value unpack so colour images do not crash.
        h, w = img.shape[:2]
        center = (w // 2, h // 2)
        M = cv2.getRotationMatrix2D(center, angle, 1.0)
        return cv2.warpAffine(img, M, (w, h), flags=cv2.INTER_CUBIC,
                              borderMode=cv2.BORDER_CONSTANT, borderValue=0)

    @staticmethod
    def shift(img, dx, dy):
        """Translate ``img`` by ``dx`` pixels horizontally and ``dy`` vertically.

        Returns:
            The shifted image, same width and height as the input.
        """
        h, w = img.shape[:2]
        M = np.float32([[1, 0, dx], [0, 1, dy]])
        return cv2.warpAffine(img, M, (w, h), flags=cv2.INTER_CUBIC,
                              borderMode=cv2.BORDER_CONSTANT, borderValue=0)

    @staticmethod
    def zoom(img, scale):
        """Zoom about the image centre while keeping the output size.

        Args:
            img: Image array.
            scale: Zoom factor. Values >= 1 crop the centre and enlarge it;
                values in (0, 1) shrink the image and pad it with black.

        Returns:
            The zoomed image, same width and height as the input. If the
            requested crop is empty (very large ``scale``), ``img`` itself is
            returned unchanged.

        Raises:
            ValueError: If ``scale`` is not a positive number.
        """
        # `not scale > 0` also rejects NaN, which `scale <= 0` would let through.
        if not scale > 0:
            raise ValueError(f"scale must be positive, got {scale!r}")

        h, w = img.shape[:2]

        if scale < 1:
            # Zoom out: a centre crop would need negative slice bounds, which
            # NumPy interprets as wrap-around. Shrink and pad instead.
            new_w = max(1, int(w * scale))
            new_h = max(1, int(h * scale))
            shrunk = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
            canvas = np.zeros_like(img)
            x0, y0 = (w - new_w) // 2, (h - new_h) // 2
            canvas[y0:y0 + new_h, x0:x0 + new_w] = shrunk
            return canvas

        center_x, center_y = w // 2, h // 2
        radius_x, radius_y = w // (2 * scale), h // (2 * scale)
        min_x, max_x = int(center_x - radius_x), int(center_x + radius_x)
        min_y, max_y = int(center_y - radius_y), int(center_y + radius_y)
        cropped = img[min_y:max_y, min_x:max_x]
        if cropped.size == 0:
            return img
        return cv2.resize(cropped, (w, h), interpolation=cv2.INTER_LINEAR)

    @staticmethod
    def thicken(img, kernel_size=2):
        """Dilate ``img`` once with a square kernel of side ``kernel_size``."""
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        return cv2.dilate(img, kernel, iterations=1)

    @staticmethod
    def thin(img, kernel_size=2):
        """Erode ``img`` once with a square kernel of side ``kernel_size``."""
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        return cv2.erode(img, kernel, iterations=1)

class MatplotlibGridExtractor:
    """Cut a scanned character grid into cells and save augmented samples.

    A sheet is treated as ``num_rows`` x ``num_cols`` cells: row ``i`` belongs
    to ``family_names[i]`` and column ``j`` is saved as ``form<j+1>``. The
    user marks the four grid corners in a Matplotlib window, the grid is
    rectified, each non-empty cell is binarised and centred on a square
    canvas, and the augmented variants are written to
    ``output_dir/<family>/``.
    """

    # Fraction of the cell size trimmed from every side before processing
    # (presumably to keep the printed grid lines out of the crop).
    CELL_MARGIN = 0.15
    # Side of the square output canvas, and of the box the glyph is scaled into.
    CANVAS_SIZE = 64
    GLYPH_SIZE = 48
    # A binarised cell needs more ink pixels than this to be processed.
    MIN_INK_PIXELS = 50

    def __init__(self, output_dir='dataset/train_augmented'):
        """Create the output directory and set the grid layout.

        Args:
            output_dir: Root directory; one sub-folder per family is created
                below it when samples are saved.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.num_rows = 10
        self.num_cols = 7
        self.family_names = ['ha', 'le', 'hha', 'me', 'se', 're', 'sa', 'sha', 'qe', 'be']

    def augment_cell(self, img):
        """Return the augmented variants of one normalised cell image.

        Args:
            img: Single-channel image, or None.

        Returns:
            ``{}`` when ``img`` is None or contains no non-zero pixel,
            otherwise the dict produced by :meth:`static_augment`.
        """
        # Guard against blank input so no all-black samples are generated.
        if img is None or cv2.countNonZero(img) == 0:
            return {}
        return self.static_augment(img)

    @staticmethod
    def static_augment(img):
        """Build 6 variants of ``img``: the original plus 5 transforms.

        Returns:
            Dict with keys ``orig``, ``rot``, ``zoom``, ``shift``, ``bold``
            and ``thin`` mapping to images.
        """
        variations = {'orig': img}
        # Rotate by 5-10 degrees in a random direction.
        angle = random.uniform(5, 10) * (1 if random.random() > 0.5 else -1)
        variations['rot'] = DataAugmentor.rotate(img, angle)
        variations['zoom'] = DataAugmentor.zoom(img, random.uniform(1.1, 1.2))
        variations['shift'] = DataAugmentor.shift(img, random.randint(-3, 3), random.randint(-3, 3))
        variations['bold'] = DataAugmentor.thicken(img)
        # Erosion can wipe out thin strokes: keep the eroded image only if more
        # than 40% of the ink survives, otherwise use a fixed -3 degree rotation.
        thinned = DataAugmentor.thin(img)
        if cv2.countNonZero(thinned) > cv2.countNonZero(img) * 0.4:
            variations['thin'] = thinned
        else:
            variations['thin'] = DataAugmentor.rotate(img, -3)
        return variations

    def manual_calibrate(self, image_path):
        """Let the user click the four grid corners in a Matplotlib window.

        Matplotlib is used for the GUI instead of OpenCV's own windows.

        Args:
            image_path: ``pathlib.Path`` of the scanned sheet.

        Returns:
            ``(points, img)`` where ``points`` is a list of four ``(x, y)``
            integer tuples in click order (TL, TR, BR, BL) and ``img`` is the
            image as loaded by ``cv2.imread``; ``(None, None)`` if the image
            cannot be read or fewer than four clicks were recorded.
        """
        img = cv2.imread(str(image_path))
        if img is None:
            return None, None

        # OpenCV loads BGR; Matplotlib expects RGB.
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        print(f"\n--- CALIBRATING: {image_path.name} ---")
        print("INSTRUCTIONS:")
        print("1. A window will pop up (check your taskbar if you don't see it).")
        print("2. Click the 4 corners: Top-Left -> Top-Right -> Bottom-Right -> Bottom-Left.")
        print("3. Close the window after clicking 4 times.")

        fig = plt.figure(figsize=(10, 10))
        try:
            plt.imshow(img_rgb)
            plt.title("Click: TL -> TR -> BR -> BL")
            plt.axis('off')
            # A non-positive timeout makes ginput wait without timing out.
            pts = plt.ginput(n=4, timeout=-1, show_clicks=True)
        finally:
            # Always release the figure, even if the GUI call fails.
            plt.close(fig)

        if len(pts) != 4:
            print("Did not record 4 clicks. Skipping.")
            return None, None

        # ginput returns float coordinates; truncate to integer pixels.
        points = [(int(x), int(y)) for x, y in pts]
        return points, img

    @staticmethod
    def _warp_grid(img, corners):
        """Rectify the clicked quadrilateral into an upright grayscale image.

        Returns None when the clicked region has no width or no height.
        """
        src = np.float32(corners)
        tl, tr, br, bl = src
        # Use the longer of each pair of opposite edges so a skewed
        # quadrilateral is not squeezed along its longer side.
        width = int(max(np.linalg.norm(tr - tl), np.linalg.norm(br - bl)))
        height = int(max(np.linalg.norm(bl - tl), np.linalg.norm(br - tr)))
        if width < 1 or height < 1:
            return None

        dst = np.float32([[0, 0], [width, 0], [width, height], [0, height]])
        matrix = cv2.getPerspectiveTransform(src, dst)
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        return cv2.warpPerspective(gray, matrix, (width, height))

    @classmethod
    def _normalize_cell(cls, cell_img):
        """Binarise one cell and centre its largest contour on a square canvas.

        Returns None when the cell has too little ink or no contour.
        """
        filtered = cv2.bilateralFilter(cell_img, 9, 75, 75)
        # Inverted Otsu threshold: ink becomes 255 on a 0 background.
        _, binary = cv2.threshold(filtered, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        if cv2.countNonZero(binary) <= cls.MIN_INK_PIXELS:
            return None

        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return None

        x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
        roi = binary[y:y + h, x:x + w]

        # Fit the glyph into GLYPH_SIZE while keeping its aspect ratio. The
        # max(1, ...) keeps very thin strokes from collapsing to a zero-sized
        # resize target, which cv2.resize rejects.
        scale = min(cls.GLYPH_SIZE / w, cls.GLYPH_SIZE / h)
        nw = max(1, int(w * scale))
        nh = max(1, int(h * scale))
        resized = cv2.resize(roi, (nw, nh))

        canvas = np.zeros((cls.CANVAS_SIZE, cls.CANVAS_SIZE), dtype=np.uint8)
        px, py = (cls.CANVAS_SIZE - nw) // 2, (cls.CANVAS_SIZE - nh) // 2
        canvas[py:py + nh, px:px + nw] = resized
        return canvas

    def extract_and_augment(self, image_path, sample_id):
        """Calibrate one sheet, slice it into cells and save augmented samples.

        Args:
            image_path: ``pathlib.Path`` of the scanned sheet.
            sample_id: Prefix used in every output file name for this sheet.
        """
        calib_points, img = self.manual_calibrate(image_path)
        if calib_points is None:
            return

        # 1. Perspective transform
        warped_grid = self._warp_grid(img, calib_points)
        if warped_grid is None:
            print("Calibrated region has zero width or height. Skipping.")
            return

        # 2. Slice grid
        grid_h, grid_w = warped_grid.shape[:2]
        cell_w = grid_w / self.num_cols
        cell_h = grid_h / self.num_rows

        total_saved = 0
        for row in range(self.num_rows):
            family = self.family_names[row]  # e.g., 'ha'
            y1 = int(row * cell_h + cell_h * self.CELL_MARGIN)
            y2 = int((row + 1) * cell_h - cell_h * self.CELL_MARGIN)
            for col in range(self.num_cols):
                # The file name records the form; the folder is the family.
                char_id = f"form{col + 1}"
                x1 = int(col * cell_w + cell_w * self.CELL_MARGIN)
                x2 = int((col + 1) * cell_w - cell_w * self.CELL_MARGIN)

                cell_img = warped_grid[y1:y2, x1:x2]
                if cell_img.size == 0:
                    continue

                glyph = self._normalize_cell(cell_img)
                if glyph is None:
                    continue

                aug_batch = self.augment_cell(glyph)
                total_saved += self.save_batch(aug_batch, family, char_id, sample_id)

        print(f"  \u2713 Saved {total_saved} images to family folders from {image_path.name}")

    def save_batch(self, batch, family_name, char_id, sample_id):
        """Write every image of ``batch`` into ``output_dir/family_name``.

        Files are named ``<sample_id>_<char_id>_<tag>_<variant>.png`` where
        ``tag`` is a random 4-digit number shared by the whole batch.

        Returns:
            Number of files actually written.
        """
        family_dir = self.output_dir / family_name
        family_dir.mkdir(parents=True, exist_ok=True)

        unique_id = random.randint(1000, 9999)

        written = 0
        for type_name, img in batch.items():
            file_path = family_dir / f"{sample_id}_{char_id}_{unique_id}_{type_name}.png"
            # cv2.imwrite reports failure through its return value, not an
            # exception, so check it to keep the saved count honest.
            if cv2.imwrite(str(file_path), img):
                written += 1
            else:
                print(f"  ! Could not write {file_path}")
        return written

    def batch_process(self, input_dir):
        """Run :meth:`extract_and_augment` on every image in ``input_dir``.

        Files with a ``.jpg``, ``.jpeg`` or ``.png`` suffix (case-insensitive)
        are processed in sorted order; the file stem is used as sample id.
        """
        input_path = Path(input_dir)
        image_suffixes = {'.jpg', '.jpeg', '.png'}
        files = sorted(
            f for f in input_path.glob('*')
            if f.is_file() and f.suffix.lower() in image_suffixes
        )
        print(f"Found {len(files)} images.")

        for i, f in enumerate(files):
            print(f"Processing [{i+1}/{len(files)}]: {f.name}")
            self.extract_and_augment(f, f.stem)

if __name__ == "__main__":
    # Script entry point: write samples under 'dataset' and process every
    # sheet found in 'scanned_sheets'.
    grid_extractor = MatplotlibGridExtractor(output_dir='dataset')
    grid_extractor.batch_process('scanned_sheets')