In [1]:
# Standard-Library Imports
import json
import os
import pickle
import re
import shutil
import sys

# Ensure correct path for relative imports.
# This has to run BEFORE the `utils.*` imports below, otherwise they fail on a fresh kernel.
sys.path.insert(0, '/Users/davide/Documents/Work/github/model_training')

# Third-Party Imports
import numpy as np
import pandas as pd
from PIL import Image
import skimage.io as sk_io
import matplotlib.pyplot as plt
import pylab
from pycocotools.coco import COCO

# Local Imports
from utils.COCO_utils import get_most_common_cat, initialize_coco, process_coco_data, create_dataframe, get_unique_entries, save_dataframe_coco, load_dataframe_coco
from utils.NSD_utils import download_and_process_nsd_data, assign_main_category, save_dataframes_nsd, load_dataframes_nsd

# Configure Matplotlib settings
pylab.rcParams['figure.figsize'] = (8.0, 10.0)

## Initialize COCO-NSD category mapping

In [21]:
"""
This script initializes and maps COCO dataset category numbers to category names and vice versa.
It also renumbers the COCO categories to a contiguous 1-80 range for use with the NSD (Natural Scenes Dataset) data.
"""

# COCO Category Mapping (Original) - maps the sparse category ids used in the COCO
# annotations (1-90 with gaps, 80 categories in total) to their category names.
coco_num_to_cat_original = {
    1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane',
    6: 'bus', 7: 'train', 8: 'truck', 9: 'boat', 10: 'traffic light',
    11: 'fire hydrant', 13: 'stop sign', 14: 'parking meter', 15: 'bench',
    16: 'bird', 17: 'cat', 18: 'dog', 19: 'horse', 20: 'sheep',
    21: 'cow', 22: 'elephant', 23: 'bear', 24: 'zebra', 25: 'giraffe',
    27: 'backpack', 28: 'umbrella', 31: 'handbag', 32: 'tie', 33: 'suitcase',
    34: 'frisbee', 35: 'skis', 36: 'snowboard', 37: 'sports ball', 38: 'kite',
    39: 'baseball bat', 40: 'baseball glove', 41: 'skateboard', 42: 'surfboard', 43: 'tennis racket',
    44: 'bottle', 46: 'wine glass', 47: 'cup', 48: 'fork', 49: 'knife',
    50: 'spoon', 51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich',
    55: 'orange', 56: 'broccoli', 57: 'carrot', 58: 'hot dog', 59: 'pizza',
    60: 'donut', 61: 'cake', 62: 'chair', 63: 'couch', 64: 'potted plant',
    65: 'bed', 67: 'dining table', 70: 'toilet', 72: 'tv', 73: 'laptop',
    74: 'mouse', 75: 'remote', 76: 'keyboard', 77: 'cell phone', 78: 'microwave',
    79: 'oven', 80: 'toaster', 81: 'sink', 82: 'refrigerator', 84: 'book',
    85: 'clock', 86: 'vase', 87: 'scissors', 88: 'teddy bear', 89: 'hair drier', 90: 'toothbrush'
}

# Sparse COCO id -> contiguous id (1..80), numbered in the order the ids are listed above.
# Zipping against range(1, 81) caps the mapping at the first 80 entries.
sub_number_map = dict(zip(coco_num_to_cat_original, range(1, 81)))

# Contiguous id (1..80) -> category name.
# Derived from the two tables above instead of being typed out a second time, so the
# renumbered names cannot drift out of sync with the original COCO table.
coco_num_to_cat = {
    sub_number_map[original_id]: cat_name
    for original_id, cat_name in coco_num_to_cat_original.items()
    if original_id in sub_number_map
}

# Reverse lookup: category name -> contiguous id (1..80)
coco_cat_to_num = {cat_name: cat_num for cat_num, cat_name in coco_num_to_cat.items()}

## Get COCO info

In [22]:
# Main script settings
# `data_dir` is handed to initialize_coco(); `store_path` is the file used by the
# save/load helpers below.
data_dir = '/Users/davide/Documents/Work/MS_COCO/Images'
store_path = '/Users/davide/Documents/Work/MS_COCO/data/coco_full.csv'
# Integer switches (1 = on, 0 = off) for the three stages below.
run_processing = 0
save_to_file = 0
load_from_file = 1

# Execution logic
# The stages run in this fixed order: rebuild `coco_pd`, save it, load it.
if run_processing:
    # Accumulate the per-split results of both COCO splits into four flat lists.
    sets = ['val2017', 'train2017']
    coco_cat_n, coco_cat_s, coco_id, coco_captions = [], [], [], []
    
    for dataset in sets:
        coco, coco_caps = initialize_coco(data_dir, dataset)
        # Uses the ORIGINAL (sparse) COCO id -> name table; the renumbering table
        # (`sub_number_map`) is only passed to create_dataframe() further down.
        cat_n, cat_s, ids, captions = process_coco_data(coco, coco_caps, coco_num_to_cat_original)
        coco_cat_n.extend(cat_n)
        coco_cat_s.extend(cat_s)
        coco_id.extend(ids)
        coco_captions.extend(captions)
    
    coco_pd_full = create_dataframe(coco_cat_n, coco_cat_s, coco_id, coco_captions, sub_number_map)
    # presumably collapses duplicate rows per image - TODO confirm in utils.COCO_utils
    coco_pd = get_unique_entries(coco_pd_full)

# NOTE(review): with run_processing = 0 this relies on `coco_pd` already existing in
# the kernel; on a fresh kernel it raises NameError.
if save_to_file:
    save_dataframe_coco(coco_pd, store_path)

# When enabled, this replaces any `coco_pd` built above with the copy read from `store_path`.
if load_from_file:
    coco_pd = load_dataframe_coco(store_path)

## Get NSD info

In [23]:
# Directory handed to the NSD save/load helpers below.
data_path = '/Users/davide/Documents/Work/MS_COCO/data/'
# Integer switches (1 = on, 0 = off); the stages below run in the order process, save, load.
run_processing = 0
save_to_file = 0
load_from_file = 1

# URL of the NSD stimulus-information table (nsd_stim_info_merged.csv)
nsd_data_url = "https://natural-scenes-dataset.s3.amazonaws.com/nsddata/experiments/nsd/nsd_stim_info_merged.csv"

if run_processing:
    # Step 1: Download and process NSD data
    # Needs `coco_pd` from the COCO cell above.
    nsd_allCat = download_and_process_nsd_data(nsd_data_url, coco_pd)
    
    # Step 2: Assign the main category and get single category data
    # `nsd_allCat` is rebound to the first returned frame.
    nsd_allCat, nsd_singleCat = assign_main_category(nsd_allCat)

# NOTE(review): with run_processing = 0 this relies on `nsd_allCat` / `nsd_singleCat`
# already existing in the kernel; on a fresh kernel it raises NameError.
if save_to_file:
    # Save the processed data to disk
    save_dataframes_nsd(nsd_allCat, nsd_singleCat, data_path)

if load_from_file:
    # Load previously saved data from disk
    # Replaces any frames built above with the stored copies.
    nsd_allCat, nsd_singleCat = load_dataframes_nsd(data_path)

## Get multi-labelled images 

In [24]:
"""
This script processes and filters MS COCO dataset annotations to create a subset of data based on the number of exemplars per category.

### Configuration
- `data_path` (str): Directory where data files are stored.
- `run` (int): Flag to indicate whether to run the processing and filtering steps (1 to run, 0 to skip).
- `store` (int): Flag to indicate whether to save the processed data to a file (1 to save, 0 to skip).
- `load` (int): Flag to indicate whether to load previously processed data from a file (1 to load, 0 to skip).

### Parameters
- `n` (int): Minimum number of images required per category to be included in the final dataset.
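- `include_persons` (bool): When False, images that also contain a 'person' annotation (category 1) are rejected and the result is stored under the `_noPersons` file name. 'person' itself is never kept as a main category.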

### Process Overview
1. **Loading Data**: If the `load` flag is set to 1, it loads the previously saved data from 'nsd_allCat.csv' and 'nsd_singleCat.csv' using `pickle`.

2. **Filtering Categories**:
    - It keeps the categories in `nsd_singleCat` that have at least `n` exemplars; category 1 (person) is always left out of this list.
    - It creates a set of categories to keep based on this filtering.

3. **Subsetting Data**:
    - Removes rows from `nsd_allCat` where 'main_cat' is NaN.
    - Converts 'main_cat' to numerical values.
    - Subsets `nsd_allCat` to include only the rows where 'main_cat' matches the filtered categories.

4. **Processing for Non-Overlapping Categories**:
    - For each row in the subset DataFrame, it identifies categories other than the main category.
    - It determines if these other categories are not in the set of categories to keep.
    - It creates a list of indices where no other categories overlap with the categories to keep.
"""

# Configuration
data_path = '/Users/davide/Documents/Work/MS_COCO/data/'
run = 0
store = 1
load = 0

# Parameters
n = 75  # Minimum number of images per category to keep
include_persons = False  # False: reject images that also contain a person (category 1)

# Result file for the selected variant, resolved once so the store and load steps
# cannot disagree. The files hold pickled DataFrames despite the '.csv' extension.
noOverlap_file = 'nsd_noOverlap.csv' if include_persons else 'nsd_noOverlap_noPersons.csv'
noOverlap_path = os.path.join(data_path, noOverlap_file)

# Load the input frames (otherwise they must already exist in the kernel as
# `nsd_singleCat` and `nsd_allCat`).
# NOTE(review): pickle.load can execute arbitrary code; only load files you created yourself.
if load:
    with open(os.path.join(data_path, 'nsd_allCat.csv'), 'rb') as f:
        nsd_allCat = pickle.load(f)
    with open(os.path.join(data_path, 'nsd_singleCat.csv'), 'rb') as f:
        nsd_singleCat = pickle.load(f)

if run:
    # Keep categories with at least `n` single-category exemplars.
    # The comparison is inclusive so a category with exactly `n` images qualifies,
    # as the documented "minimum number of images" requires.
    cat_counts = nsd_singleCat['cat_n'].value_counts()
    mask = cat_counts >= n
    cat_sub = cat_counts[mask]
    # 'person' (category 1) is never used as a main category.
    # NOTE(review): `i[0]` assumes every index entry is a sequence (e.g. a one-element
    # tuple) rather than a bare int - confirm against how `cat_n` is built.
    cat_sub = [i[0] for i in cat_sub.index if i[0] != 1]
    cat_to_keep = set(cat_sub)
    cat_to_keep_str = [coco_num_to_cat[val] for val in cat_to_keep]

    # Drop rows without a main category. `.copy()` gives an independent frame so the
    # column assignment below does not write into a slice of `nsd_allCat`.
    nsd_allCat_naSub = nsd_allCat[~nsd_allCat['main_cat'].isna()].copy()
    nsd_allCat_naSub['main_cat_n'] = nsd_allCat_naSub['main_cat'].map(coco_cat_to_num)

    # Subset DataFrame to keep only relevant categories (again as an independent copy,
    # because the 'other_cats_n' column is added to it below).
    nsd_allCat_sub = nsd_allCat_naSub[nsd_allCat_naSub['main_cat'].isin(cat_to_keep_str)].copy()

    # For the no-persons variant, treat 'person' as a blocking category: any image that
    # has a person among its other categories is rejected by the overlap test below.
    # This is done after `cat_to_keep_str` was built, so the subset above is unaffected.
    if not include_persons:
        cat_to_keep.add(1)

    # Walk the two columns once instead of indexing a full row per lookup.
    other_cats_n = []  # per image: its categories minus the main one
    non_overlapping_cat_idx = []  # positional indices of images that pass the test
    for idx, (act_cats, act_main_cat) in enumerate(
            zip(nsd_allCat_sub['cat_n'], nsd_allCat_sub['main_cat_n'])):
        filtered_list = [elem for elem in act_cats if elem != act_main_cat]
        other_cats_n.append(filtered_list)
        # An image qualifies when none of its other categories is a blocking category.
        if cat_to_keep.isdisjoint(filtered_list):
            non_overlapping_cat_idx.append(idx)

    nsd_allCat_sub['other_cats_n'] = other_cats_n
    nsd_noOverlap = nsd_allCat_sub.iloc[non_overlapping_cat_idx]

if store:
    # Check BEFORE opening: open(..., 'wb') truncates the existing file, so failing on
    # a missing variable afterwards would wipe the previously stored result.
    if 'nsd_noOverlap' not in globals():
        raise NameError(
            "nsd_noOverlap is not defined; set run = 1 to compute it before storing"
        )
    with open(noOverlap_path, 'wb') as f:
        pickle.dump(nsd_noOverlap, f)

# Load a previously stored result for the selected variant.
if load:
    with open(noOverlap_path, 'rb') as f:
        nsd_noOverlap = pickle.load(f)

## Cropping and storing images 

In [61]:
# Configuration
data_path = '/Users/davide/Documents/Work/MS_COCO/data/'
which_dataset = 'nsd_singleCat'  # Options: 'nsd_singleCat', 'nsd_noOverlap', 'nsd_noOverlap_noPersons'

# Input file for each selectable dataset (read with pickle below, despite the '.csv' extension)
datasets = {
    'nsd_singleCat': 'nsd_singleCat.csv',
    'nsd_noOverlap': 'nsd_noOverlap.csv',
    'nsd_noOverlap_noPersons': 'nsd_noOverlap_noPersons.csv'
}

# Output root for the cropped images of each selectable dataset
directories = {
    'nsd_singleCat': '/Users/davide/Documents/Work/MS_COCO/nsd_images/single_cat/',
    'nsd_noOverlap': '/Users/davide/Documents/Work/MS_COCO/nsd_images/multiple_cats/',
    'nsd_noOverlap_noPersons': '/Users/davide/Documents/Work/MS_COCO/nsd_images/multiple_cats_no_persons/'
}

# Fail early with a readable message. A mistyped name would otherwise turn into None
# and surface as an unrelated TypeError inside os.path.join.
if which_dataset not in datasets or which_dataset not in directories:
    raise ValueError(
        f"Unknown which_dataset {which_dataset!r}; expected one of {sorted(datasets)}"
    )

# Load data based on the dataset selection
# NOTE(review): pickle.load can execute arbitrary code; only load files you created yourself.
dataset_file = datasets[which_dataset]
with open(os.path.join(data_path, dataset_file), 'rb') as f:
    pd_to_use = pickle.load(f)

cropped_images_path = directories[which_dataset]

# Flags (1 = on, 0 = off)
make_cat_dir = 1  # create one output sub-directory per category
run_crop = 1  # crop, resize and save the images
# Create directories if needed
def create_dirs(directory, categories):
    """Create one sub-directory of `directory` for every entry in `categories`.

    Args:
        directory: Parent directory path. Missing parent directories are created too.
        categories: Iterable of sub-directory names.

    Returns:
        None. Directories that already exist are left untouched.
    """
    for category in categories:
        # exist_ok=True replaces the separate isdir() check-then-create step.
        os.makedirs(os.path.join(directory, category), exist_ok=True)

# One output sub-directory per distinct (non-missing) main category.
if make_cat_dir:
    categories = list(pd_to_use['main_cat'].dropna().unique())
    create_dirs(cropped_images_path, categories)

# Work only on rows that have a main category, re-indexed 0..N-1.
nsd_pd = pd_to_use[pd.notna(pd_to_use['main_cat'])].reset_index(drop=True)
coco_imgs_folder = '/Users/davide/Documents/Work/MS_COCO/Images/trainVal2017/'

# Build the source path (COCO image) and destination path (cropped NSD image) per row.
source_paths = []
destination_paths = []
path_inputs = zip(nsd_pd['id'], nsd_pd['nsdId'], nsd_pd['main_cat'])
for row_no, (coco_image_id, nsd_image_id, category) in enumerate(path_inputs):
    if row_no % 5000 == 0:
        print(f'Adding path {row_no}')
    # Source file name: the image id left-padded with zeros to 12 characters.
    coco_file = str(coco_image_id).rjust(12, '0') + '.jpg'
    nsd_file = str(nsd_image_id) + '.jpg'
    source_paths.append(os.path.join(coco_imgs_folder, coco_file))
    destination_paths.append(os.path.join(cropped_images_path, category, nsd_file))

nsd_pd['source_paths'] = source_paths
nsd_pd['destination_paths'] = destination_paths

print('\n\nStart cropping and saving images')

# Crop, resize and save every image when the run_crop flag is set.
if run_crop:
    resize_dim = (425, 425)

    crop_jobs = zip(nsd_pd['source_paths'], nsd_pd['destination_paths'], nsd_pd['cropBox'])
    for i, (source_path, destination_path, act_crop_box) in enumerate(crop_jobs):
        if i % 100 == 0:
            print(f'Processing image {i}')

        # The crop box is stored as text such as "(a, b, c, d)"; turn it into floats.
        box_fields = act_crop_box.strip().strip('()').split(',')
        act_crop_box_t = tuple(float(field) for field in box_fields)

        # Missing source images are reported and skipped.
        if not os.path.isfile(source_path):
            print(f'Warning: File not found {source_path}')
            continue

        with Image.open(source_path) as img:
            w, h = img.size

            # The four values are fractions of the image size in the order
            # (top, bottom, left, right); right/bottom are measured from the far edge.
            top_pct, bottom_pct, left_pct, right_pct = act_crop_box_t
            crop_region = (
                int(left_pct * w),
                int(top_pct * h),
                int((1 - right_pct) * w),
                int((1 - bottom_pct) * h),
            )

            # Crop, resize to the target size, and write to the destination.
            cropped_img = img.crop(crop_region)
            resized_img = cropped_img.resize(resize_dim, Image.Resampling.LANCZOS)
            resized_img.save(destination_path)

    print(f'Finished processing images for dataset: {which_dataset}')

Adding path 0
Adding path 5000
Adding path 10000
Adding path 15000
Start cropping and saving images
Processing image 0
Processing image 100
Processing image 200
Processing image 300
Processing image 400
Processing image 500
Processing image 600
Processing image 700
Processing image 800
Processing image 900
Processing image 1000
Processing image 1100
Processing image 1200
Processing image 1300
Processing image 1400
Processing image 1500
Processing image 1600
Processing image 1700
Processing image 1800
Processing image 1900
Processing image 2000
Processing image 2100
Processing image 2200
Processing image 2300
Processing image 2400
Processing image 2500
Processing image 2600
Processing image 2700
Processing image 2800
Processing image 2900
Processing image 3000
Processing image 3100
Processing image 3200
Processing image 3300
Processing image 3400
Processing image 3500
Processing image 3600
Processing image 3700
Processing image 3800
Processing image 3900
Processing image 4000
Processing