In [33]:
import json
import random
import boto3
from botocore.exceptions import ClientError
from collections import defaultdict

# S3 client shared by the helper functions in this notebook
# (download, copy and delete all go through this one object).
s3 = boto3.client("s3")

# ---- S3 locations ----------------------------------------------------------
# Bucket that holds the dataset.
BUCKET_NAME = "lemondataset"
# Key of the annotation JSON inside the bucket.
JSON_KEY = "lemon-dataset/lemon-dataset/annotations/instances_default.json"
# Prefix under which the original images are stored.
IMAGE_PREFIX = "lemon-dataset/lemon-dataset/images/"
# Destination prefixes, one per split.
TRAIN_PREFIX = "lemon-dataset/lemon-dataset/train/"
VAL_PREFIX = "lemon-dataset/lemon-dataset/validation/"
TEST_PREFIX = "lemon-dataset/lemon-dataset/test/"

# ---- Split sizes -----------------------------------------------------------
# Fraction of all images that goes to the test / validation split.
TEST_RATIO = 0.1
VAL_RATIO = 0.1

# Fetch the annotation file from S3 and parse it.
# NOTE: download_json_from_s3 is defined in a later cell of this notebook, so
# that cell has to be executed before this one on a fresh kernel.
annot_json = download_json_from_s3(BUCKET_NAME, JSON_KEY)

# JSON is UTF-8 encoded; pass the encoding explicitly instead of relying on the
# platform default (which is not UTF-8 on every system).
with open(annot_json, "r", encoding="utf-8") as f:
    data = json.load(f)

# Per-image records (each has an "id" and a "file_name") and the annotations
# (each has an "image_id" and a "category_id") used throughout the notebook.
images = data["images"]
annotations = data["annotations"]

# image id -> list of the distinct category ids annotated on that image,
# in the order the categories are first seen.
img_to_cats = {}

for ann in annotations:
    cats_for_image = img_to_cats.setdefault(ann["image_id"], [])
    if ann["category_id"] not in cats_for_image:
        cats_for_image.append(ann["category_id"])



In [42]:
# some helper functions I got from GPT 

def download_json_from_s3(bucket, key, local_path="annotations.json"):
    """Fetch the annotation JSON from S3 and save it locally.

    Args:
        bucket: Name of the S3 bucket to read from.
        key: Object key of the JSON file inside the bucket.
        local_path: Where to write the downloaded file.

    Returns:
        ``local_path``, so the caller can open the file directly.
    """
    s3.download_file(Bucket=bucket, Key=key, Filename=local_path)
    return local_path

def copy_s3_object(bucket, source_key, dest_key):
    """Duplicate ``source_key`` as ``dest_key`` inside a single bucket.

    Args:
        bucket: Bucket that holds the source object and receives the copy.
        source_key: Key of the object to copy.
        dest_key: Key to write the copy to.
    """
    s3.copy_object(
        Bucket=bucket,
        Key=dest_key,
        CopySource={"Bucket": bucket, "Key": source_key},
    )

def copy_to_folder(image_ids, target_prefix):
    """Copy the images with the given ids to ``target_prefix`` in the bucket.

    Each id is resolved to a file name through ``id_to_filename``; the copy is
    stored under ``target_prefix`` using only the last path component of that
    file name. A failed copy is reported and the loop moves on to the next id.

    Args:
        image_ids: Iterable of image ids to copy.
        target_prefix: Destination key prefix (e.g. the train/val/test prefix).
    """
    for image_id in image_ids:
        file_name = id_to_filename[image_id]
        source_key = get_source_key(file_name)

        # Drop any folder part so the copy lands directly under target_prefix.
        base_name = file_name.rsplit("/", 1)[-1]
        dest_key = f"{target_prefix}{base_name}"

        try:
            copy_s3_object(BUCKET_NAME, source_key, dest_key)
        except ClientError as e:
            print(f"Error copying {file_name} to {target_prefix}: {e}")

# This last one just because I've played around with several ways to split and I don't want to duplicate/ fill up the buckets

def delete_all_objects_with_prefix(bucket, prefix):
    """Delete every object in ``bucket`` whose key starts with ``prefix``.

    The keys are listed page by page and each page is removed with a single
    batch delete. Keys that S3 reports as not deleted are printed, and the
    per-page summary counts only the keys that were actually removed.

    Args:
        bucket: Name of the S3 bucket.
        prefix: Key prefix selecting the objects to delete.
    """
    paginator = s3.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        # A page without "Contents" means nothing matched; there is nothing to delete.
        objects_to_delete = [{'Key': obj['Key']} for obj in page.get('Contents', [])]
        if not objects_to_delete:
            continue

        response = s3.delete_objects(Bucket=bucket, Delete={'Objects': objects_to_delete})

        # delete_objects does not raise for individual keys it could not delete;
        # those are listed under "Errors" in the response, so surface them here.
        errors = (response or {}).get('Errors', [])
        for err in errors:
            print(f"Failed to delete {err.get('Key')}: {err.get('Code')} - {err.get('Message')}")

        print(f"Deleted {len(objects_to_delete) - len(errors)} objects from {prefix}")


In [43]:


# Delete all objects in train, validation, and test folders
# Start from empty split folders so a re-run does not pile new copies on top of
# the results of a previous split.
for prefix in (TRAIN_PREFIX, VAL_PREFIX, TEST_PREFIX):
    delete_all_objects_with_prefix(bucket=BUCKET_NAME, prefix=prefix)

Deleted 1000 objects from lemon-dataset/lemon-dataset/train/
Deleted 126 objects from lemon-dataset/lemon-dataset/train/
Deleted 269 objects from lemon-dataset/lemon-dataset/validation/
Deleted 269 objects from lemon-dataset/lemon-dataset/test/


In [36]:
# Lookup table: category id -> human-readable category name.
category_id_to_name = {}
for category in data["categories"]:
    category_id_to_name[category["id"]] = category["name"]

print("Categories:")
for cat_id in category_id_to_name:
    print(f"ID {cat_id}: {category_id_to_name[cat_id]}")



Categories:
ID 1: image_quality
ID 2: illness
ID 3: gangrene
ID 4: mould
ID 5: blemish
ID 6: dark_style_remains
ID 7: artifact
ID 8: condition
ID 9: pedicel


In [37]:
# category id -> list of the distinct image ids that carry that category.
#
# The de-duplication uses a dict per category (dicts keep insertion order), so
# each annotation is handled in constant time instead of scanning a list that
# can hold thousands of ids. The resulting lists have the same content and
# order as a first-seen list scan would produce.
seen_per_cat = {}

for ann in annotations:
    seen_per_cat.setdefault(ann["category_id"], {})[ann["image_id"]] = None

cat_to_imgs = {cat_id: list(img_ids) for cat_id, img_ids in seen_per_cat.items()}

# Print the total number of unique images per class
for cat_id, img_ids in cat_to_imgs.items():
    count = len(img_ids)
    print("Category", cat_id, ":", count, "images")

Category 9 : 1245 images
Category 5 : 2048 images
Category 2 : 1743 images
Category 7 : 451 images
Category 6 : 467 images
Category 3 : 449 images
Category 1 : 5 images
Category 4 : 264 images
Category 8 : 2 images


Originally I was going to treat illness, gangrene and mould as bad categories and make sure to balance them that way.
But because most images have at least some patches of illness, we don't have to worry about balancing that class: a random sample will handle it.
We'll balance based on gangrene and mould, which are rarer.

In [38]:
# Category ids that mark a lemon as "bad": 3 = gangrene, 4 = mould.
BAD_CATEGORIES = [3, 4]

In [39]:
# Ids of the images that carry at least one of the bad categories.
# Images without any annotation have no entry in img_to_cats and are skipped.
bad_image_ids = [
    img["id"]
    for img in images
    if any(cat in BAD_CATEGORIES for cat in img_to_cats.get(img["id"], []))
]

# Image id -> file name, needed later to build the S3 keys for the copy step.
id_to_filename = {img["id"]: img["file_name"] for img in images}

# Every image id, in the order the images appear in the annotation file.
all_ids = [img["id"] for img in images]


# Use a dedicated, seeded random generator so that re-running the notebook
# reproduces exactly the same test / validation / train split.
SPLIT_SEED = 42
rng = random.Random(SPLIT_SEED)

# Shuffle the ids, then take the test and validation samples off the front.
rng.shuffle(all_ids)

n_test = int(TEST_RATIO * len(all_ids))
n_val = int(VAL_RATIO * len(all_ids))

test_ids = all_ids[:n_test]
val_ids = all_ids[n_test : n_test + n_val]
remain_ids = all_ids[n_test + n_val:]

# Set view of the bad ids: constant-time membership checks instead of scanning
# the list once per image.
bad_id_set = set(bad_image_ids)

# Of the remaining images, every bad lemon goes into training.
train_bad = [img_id for img_id in remain_ids if img_id in bad_id_set]

# Drop the bad lemons from the remaining pool; what is left are the good ones.
remain_ids = [img_id for img_id in remain_ids if img_id not in bad_id_set]
remain_good = list(remain_ids)

# Randomly sample good images, but only up to the number of bad ones, so the
# training set is balanced between bad and good lemons.
bad_count = len(train_bad)
train_good = rng.sample(remain_good, min(bad_count, len(remain_good)))
train_ids = train_bad + train_good

print(f"Test: {len(test_ids)} | Val: {len(val_ids)} | Train: {len(train_ids)}")
print(f"  # bad lemons in train: {len(train_bad)}, # good lemons in train: {len(train_good)}")

Test: 269 | Val: 269 | Train: 1126
  # bad lemons in train: 563, # good lemons in train: 563


In [40]:
## Had to add this because the file names in the JSON include the "images/" folder, which was messing with the bucket prefixes
# gpt helped me with this
def get_source_key(file_name):
    """Build the S3 key of the original image for an annotation file name.

    A leading ``"images/"`` on the file name is dropped before ``IMAGE_PREFIX``
    is prepended; any other file name is used unchanged.

    Args:
        file_name: File name as stored in the annotation JSON.

    Returns:
        The key of the image under ``IMAGE_PREFIX``.
    """
    folder = "images/"
    relative_name = file_name[len(folder):] if file_name.startswith(folder) else file_name
    return IMAGE_PREFIX + relative_name

In [41]:
# Copy each split into its own prefix, announcing every step before it starts.
copy_jobs = (
    ("Copying test images...", test_ids, TEST_PREFIX),
    ("Copying validation images...", val_ids, VAL_PREFIX),
    ("Copying train images ...", train_ids, TRAIN_PREFIX),
)

for message, split_ids, split_prefix in copy_jobs:
    print(message)
    copy_to_folder(split_ids, split_prefix)

print("Done!")


Copying test images...
Copying validation images...
Copying train images (bad + good)...
Done!
