In [1]:
!pip install pytest pathlib



In [5]:
import unittest
import os
import shutil
from pathlib import Path
import tempfile
from typing import Dict, List

# Original code
def mount_drive():
    # Mock implementation for testing
    print("Drive mounted successfully")

def create_directory_structure(output_dir: str):
    categories = ['VeryRipe', 'Unripe', 'Ripe', 'OverRipe']

    for split in ['valid', 'train', 'test']:
        split_dir = os.path.join(output_dir, split)
        os.makedirs(split_dir, exist_ok=True)
        for category in categories:
            category_dir = os.path.join(split_dir, category)
            os.makedirs(category_dir, exist_ok=True)

def sort_bananas(source_dir: str, output_dir: str):
    # Create directory structure
    create_directory_structure(output_dir)

    # Define category mapping
    category_mapping = {
        'unripe': 'Unripe',
        'ripe': 'Ripe',
        'veryripe': 'VeryRipe',
        'overripe': 'OverRipe'
    }

    # Process each split
    for split in ['valid', 'train', 'test']:
        print(f"\nProcessing {split} split...")
        source_split_dir = os.path.join(source_dir, split)

        if not os.path.exists(source_split_dir):
            print(f"Warning: {source_split_dir} not found")
            continue

        # Process each category
        for src_category in os.listdir(source_split_dir):
            category_path = os.path.join(source_split_dir, src_category)

            if not os.path.isdir(category_path):
                continue

            # Determine destination category from source directory name
            src_category_lower = src_category.lower()
            dest_category = None
            for keyword, category in category_mapping.items():
                if keyword == src_category_lower:
                    dest_category = category
                    break

            # If no match was found based on directory name, use a default
            if dest_category is None:
                dest_category = 'Unripe'  # Default category

            # Process images
            for filename in os.listdir(category_path):
                if not filename.lower().endswith(('.jpg', '.jpeg', '.png')):
                    continue

                # Special case for rotten images regardless of source dir
                file_dest_category = dest_category
                if 'rotten' in filename.lower():
                    file_dest_category = 'OverRipe'

                # Set up paths
                source_path = os.path.join(category_path, filename)
                # Preserve the original case of the filename
                dest_folder = os.path.join(output_dir, split, file_dest_category)
                dest_path = os.path.join(dest_folder, filename)

                # Copy file
                try:
                    shutil.copy2(source_path, dest_path)
                    print(f"Copied {filename} to {file_dest_category}")
                except Exception as e:
                    print(f"Error copying {filename}: {e}")

def print_summary(output_dir: str):
    print("\nSorting Summary:")

    for split in ['valid', 'train', 'test']:
        split_dir = os.path.join(output_dir, split)
        if not os.path.exists(split_dir):
            continue

        print(f"\n{split.upper()}:")
        split_total = 0

        for category in sorted(os.listdir(split_dir)):
            category_path = os.path.join(split_dir, category)
            if os.path.isdir(category_path):
                image_count = len([f for f in os.listdir(category_path)
                                 if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
                print(f"{category}: {image_count} images")
                split_total += image_count

        print(f"Total {split} images: {split_total}")

class TestBananaClassification(unittest.TestCase):
    def setUp(self):
        """Set up test environment before each test"""
        self.test_dir = tempfile.mkdtemp()
        self.source_dir = os.path.join(self.test_dir, 'source')
        self.output_dir = os.path.join(self.test_dir, 'output')

    def tearDown(self):
        """Clean up test environment after each test"""
        shutil.rmtree(self.test_dir)

    def test_create_directory_structure(self):
        """Test if directory structure is created correctly"""
        create_directory_structure(self.output_dir)

        expected_categories = ['VeryRipe', 'Unripe', 'Ripe', 'OverRipe']

        for split in ['valid', 'train', 'test']:
            for category in expected_categories:
                dir_path = os.path.join(self.output_dir, split, category)
                self.assertTrue(os.path.exists(dir_path),
                              f"Directory not created: {dir_path}")

    def test_sort_bananas_file_categorization(self):
      """Test if files are correctly categorized"""
      # Create source directory structure
      os.makedirs(os.path.join(self.source_dir, 'train', 'unripe'), exist_ok=True)

      # Create test file - note the consistent capitalization
      test_file = os.path.join(self.source_dir, 'train', 'unripe', 'Ripe_test1.jpg')
      Path(test_file).touch()

      # Run sorting
      sort_bananas(self.source_dir, self.output_dir)

      # Check if file was correctly categorized
      expected_path = os.path.join(self.output_dir, 'train', 'Unripe', 'Ripe_test1.jpg')
      self.assertTrue(os.path.exists(expected_path),
                   f"File not found in expected location: {expected_path}")

    def test_print_summary(self):
        """Test if summary printing works correctly"""
        # Create test files
        test_structure = {
            'train': {
                'VeryRipe': 2,
                'Unripe': 3,
                'OverRipe': 1
            },
            'valid': {
                'Ripe': 2,
                'Unripe': 1
            }
        }

        # Create test files
        for split, categories in test_structure.items():
            for category, count in categories.items():
                category_dir = os.path.join(self.output_dir, split, category)
                os.makedirs(category_dir, exist_ok=True)
                for i in range(count):
                    file_path = os.path.join(category_dir, f"test{i}.jpg")
                    Path(file_path).touch()

        # Capture printed output
        import io
        import sys
        captured_output = io.StringIO()
        sys.stdout = captured_output

        print_summary(self.output_dir)

        sys.stdout = sys.__stdout__
        output = captured_output.getvalue()

        # Verify output contains expected information
        self.assertIn("TRAIN:", output)
        self.assertIn("VALID:", output)
        self.assertIn("Total train images: 6", output)
        self.assertIn("Total valid images: 3", output)

if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], verbosity=2, exit=False)


test_create_directory_structure (__main__.TestBananaClassification.test_create_directory_structure)
Test if directory structure is created correctly ... ok
test_print_summary (__main__.TestBananaClassification.test_print_summary)
Test if summary printing works correctly ... ok
test_sort_bananas_file_categorization (__main__.TestBananaClassification.test_sort_bananas_file_categorization)
Test if files are correctly categorized ... ok

----------------------------------------------------------------------
Ran 3 tests in 0.014s

OK
