In [1]:
import ast
import csv
import json
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

### Convert the CSV file to a dictionary in the standard MMLab data format

In [None]:
def csv_to_dict(src_dir, in_csv_path, out_json_path):
    """Convert an annotation CSV into a ``{"metainfo", "data_list"}`` JSON file.

    Every CSV column except ``''``, ``'center_lon'``, ``'center_lat'`` and
    ``'chip_id'`` is treated as a class. For each row, ``gt_label`` holds the
    indices (into the class list) of the columns whose cell parses to the
    integer 1, and ``img_path`` is ``"{src_dir}/chip_{chip_id}.tif"``.

    Args:
        src_dir (str): Directory prefix used to build each ``img_path``.
        in_csv_path (str): Path of the CSV file to read.
        out_json_path (str): Path of the JSON file to write (overwritten).

    Returns:
        None. The result is only written to ``out_json_path``.

    Raises:
        KeyError: If a row has no ``chip_id`` column.
        ValueError: If a class cell cannot be parsed by ``int()``.
    """
    non_class_columns = ("", "center_lon", "center_lat", "chip_id")
    result_dict = {"metainfo": {}, "data_list": []}

    # newline='' lets the csv module do its own newline handling, as its
    # documentation requires for files passed to a reader.
    with open(in_csv_path, 'r', newline='') as csvfile:
        reader = csv.DictReader(csvfile)

        # fieldnames is None when the file has no header line at all; treat
        # that as "no classes" instead of failing on iteration over None.
        classes = [
            field for field in (reader.fieldnames or [])
            if field not in non_class_columns
        ]
        result_dict["metainfo"]["classes"] = classes

        # Convert each row to the required dictionary format
        for row in reader:
            result_dict["data_list"].append({
                "img_path": f"{src_dir}/chip_{row['chip_id']}.tif",
                "gt_label": [
                    i for i, class_name in enumerate(classes)
                    if int(row[class_name]) == 1
                ],
            })

    # Save the result dictionary as a JSON file
    with open(out_json_path, "w") as jsonfile:
        json.dump(result_dict, jsonfile, indent=4)


In [None]:
# Test the function: run csv_to_dict on the project's annotation CSV and write
# dataset_dict.json next to it.
# NOTE(review): these are absolute, machine-specific paths (they look like a
# WSL mount of a Windows drive); adjust them before running elsewhere.
src_dir = "/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/chips"
in_csv_path = "/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/annotations/CEO_plot_id.csv"  
out_json_path = "/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/annotations/dataset_dict.json"
csv_to_dict(src_dir, in_csv_path, out_json_path)

In [None]:
class CSVConverter:
    """Parse an annotation CSV into a ``{"metainfo", "data_list"}`` dictionary.

    The layout is positional: the first three columns are skipped, the last
    column supplies the chip id used in the image file name, and every column
    in between is a class. (Presumably the skipped columns are the row index,
    ``center_lon`` and ``center_lat`` -- confirm against the CSV.)

    The CSV is parsed once at construction time; the result is available via
    :meth:`get_data`.

    Args:
        src_dir (str): Directory prefix used to build each ``img_path``.
        in_csv_path (str): Path of the CSV file to read.
    """

    def __init__(self, src_dir, in_csv_path):
        self.src_dir = src_dir
        self.in_csv_path = in_csv_path
        self.data = self._parse_csv()

    def _parse_csv(self):
        """Read the CSV and build the dataset dictionary.

        Blank lines are ignored, and a file without any rows yields an empty
        class list and an empty data list.

        Returns:
            dict: ``{"metainfo": {"classes": [...]}, "data_list": [...]}``
            where each data item has ``img_path`` and ``gt_label`` (indices of
            the class columns whose cell parses to the integer 1).

        Raises:
            ValueError: If a class cell cannot be parsed by ``int()``.
        """
        # newline='' is what the csv module documentation asks for on reader
        # inputs. csv.reader yields [] for blank lines; drop those so a stray
        # empty line does not crash on row[-1].
        with open(self.in_csv_path, "r", newline="") as f:
            rows = [row for row in csv.reader(f) if row]

        if not rows:
            return {"metainfo": {"classes": []}, "data_list": []}

        header, *records = rows
        classes = header[3:-1]

        data_list = [
            {
                "img_path": f"{self.src_dir}/chip_{row[-1]}.tif",
                "gt_label": [
                    i for i, value in enumerate(row[3:-1]) if int(value) == 1
                ],
            }
            for row in records
        ]

        return {
            "metainfo": {
                "classes": classes
            },
            "data_list": data_list
        }

    def get_data(self):
        """Return the dictionary produced when the CSV was parsed."""
        return self.data

In [None]:
# Inputs for the CSVConverter demo below: the chip directory and the annotation
# CSV (same values as in the csv_to_dict test cell above).
# NOTE(review): absolute, machine-specific paths.
src_dir = "/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/chips"
in_csv_path = "/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/annotations/CEO_plot_id.csv" 

In [None]:
# Parse the CSV with the class-based converter and print the whole resulting
# dictionary (metainfo plus every data item), so the output can be long.
converter = CSVConverter(src_dir, in_csv_path)
print(converter.get_data())

### Custom Dataset

In [2]:
from mmpretrain.registry import DATASETS
from mmpretrain.datasets.base_dataset import BaseDataset
from typing import List, Any, Union, Optional

#@DATASETS.register_module()
class MultiLabelDataset(BaseDataset):
    """Multi-label dataset for image classification.

    Extends ``BaseDataset`` so that the annotation file must be a ``.json``
    file and ``get_cat_ids`` returns the ``gt_label`` entry of a sample as-is.

    Args:
        ann_file (str): Annotation file path; must end with ``.json``.
        metainfo (dict, optional): Meta information for dataset, such as
            class information.
        **kwargs: Any other arguments, forwarded unchanged to
            ``BaseDataset.__init__``.

    Raises:
        ValueError: If ``ann_file`` does not end with ``.json``.
    """

    def __init__(self,
                 ann_file: str,
                 metainfo: Optional[dict] = None,
                 **kwargs: Any):
        # Reject non-JSON annotation files before the parent tries to load them.
        if not ann_file.endswith('.json'):
            raise ValueError("Annotation file must be a .json file")

        # Call the parent class's init method
        super().__init__(ann_file=ann_file, metainfo=metainfo, **kwargs)

    # typing.List (already imported alongside this class) is used instead of
    # the builtin generic list[int], which is evaluated at definition time and
    # fails on Python versions older than 3.9.
    def get_cat_ids(self, idx: int) -> List[int]:
        """Get category ids by index.

        Args:
            idx (int): Index of data.

        Returns:
            List[int]: Image categories of specified index, i.e. the value
            stored under ``'gt_label'`` for that sample.

        Raises:
            KeyError: If the sample has no ``'gt_label'`` entry.
        """
        data_info = self.get_data_info(idx)
        if 'gt_label' not in data_info:
            raise KeyError(f"'gt_label' not found in data_info for index {idx}")

        return data_info['gt_label']

In [3]:
# Build the dataset from dataset_dict.json, the same path csv_to_dict was asked
# to write in the test cell above.
complete_dataset = MultiLabelDataset("/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/annotations/dataset_dict.json")

In [4]:
# Display the last sample of the dataset as a quick sanity check.
complete_dataset[-1]

{'img_path': '/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/chips/chip_049_049.tif',
 'gt_label': [0, 1, 3, 4, 5, 6, 7],
 'sample_idx': 2435}

### Get the split indices for the training and validation sets

In [None]:
# Write every sample of a dataset to a CSV file (chip name, labels, sample index)
def save_to_csv(dataset, output_csv_path):
    """Write one CSV row per sample: chip file name, label, sample index.

    Args:
        dataset: Iterable of samples; each sample is indexed by the keys
            ``'img_path'``, ``'sample_idx'`` and ``'gt_label'``.
        output_csv_path (str): Destination CSV path (overwritten).

    The chip name is the part of ``img_path`` after the last ``'/'``.
    """
    def _as_row(sample):
        # Look the fields up in the same order as the columns are consumed
        # downstream: path first, then index, then label.
        path = sample['img_path']
        idx = sample['sample_idx']
        gt = sample['gt_label']
        return [path.rsplit('/', 1)[-1], gt, idx]

    with open(output_csv_path, mode='w', newline='') as out_file:
        writer = csv.writer(out_file)
        # Header row first, then the samples in iteration order.
        writer.writerow(['chip_name', 'label', 'sample_idx'])
        for entry in dataset:
            writer.writerow(_as_row(entry))

In [None]:
# Dump every sample of complete_dataset to output.csv; the split cell below
# reads this same file back.
save_to_csv(complete_dataset, "/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/output.csv")

In [None]:
def plot_label_distribution(labels, title):
    """Draw a bar chart of how often each label occurs.

    Args:
        labels: Iterable of per-sample label collections (for example lists of
            label indices); every occurrence of a label counts once.
        title (str): Title of the chart.

    If no label occurs at all (empty input, or only empty label lists), an
    empty chart with the given title is shown instead of raising.
    """
    label_freq = {}
    for label_list in labels:
        for label in label_list:
            label_freq[label] = label_freq.get(label, 0) + 1

    # zip(*...) over an empty mapping produces nothing to unpack, so the empty
    # case needs explicit empty sequences. Separate local names also avoid
    # shadowing the `labels` parameter.
    if label_freq:
        label_ids, freqs = zip(*sorted(label_freq.items()))
    else:
        label_ids, freqs = [], []

    # Plotting
    plt.bar(label_ids, freqs)
    plt.xlabel('Label Index')
    plt.ylabel('Frequency')
    plt.title(title)
    plt.xticks(label_ids)
    plt.show()


def save_sample_idx_to_txt(samples, output_txt_path):
    """Write the ``'sample_idx'`` of each sample to a text file, one per line.

    Args:
        samples: Iterable of samples indexed by the key ``'sample_idx'``.
        output_txt_path (str): Destination text file (overwritten).
    """
    with open(output_txt_path, 'w') as out_file:
        # Lazily produce one newline-terminated line per sample.
        out_file.writelines(
            str(entry['sample_idx']) + '\n' for entry in samples
        )


def split_and_plot_dataset(csv_path, split_ratio, seed=None):
    """Randomly split the rows of a CSV into two groups and plot both label distributions.

    Args:
        csv_path (str): CSV file with a ``'label'`` column holding the string
            form of a Python list (e.g. ``"[0, 2]"``).
        split_ratio (float): Fraction of rows, between 0 and 1, that goes into
            the first (selected) group.
        seed (int, optional): If given, the shuffle uses its own
            ``random.Random(seed)`` so the split is reproducible. The default
            ``None`` shuffles with the module-level ``random`` state.

    Returns:
        tuple: ``(selected_samples, remaining_samples)``, two lists of row
        dictionaries.

    Raises:
        ValueError: If ``split_ratio`` is outside ``[0, 1]``, or a label cell
            is not a valid Python literal.
    """
    if not 0 <= split_ratio <= 1:
        raise ValueError(f"split_ratio must be between 0 and 1, got {split_ratio}")

    # Read CSV
    df = pd.read_csv(csv_path)
    # Convert string representation of list to list. literal_eval only accepts
    # Python literals, so a crafted CSV cell cannot execute code the way
    # eval() would.
    df['label'] = df['label'].apply(ast.literal_eval)
    all_samples = df.to_dict('records')

    # Shuffle and split
    if seed is None:
        random.shuffle(all_samples)
    else:
        random.Random(seed).shuffle(all_samples)
    split_index = int(len(all_samples) * split_ratio)
    selected_samples = all_samples[:split_index]
    remaining_samples = all_samples[split_index:]

    # Plot label distribution
    plot_label_distribution([sample['label'] for sample in selected_samples], 'Selected Sample Label Distribution')
    plot_label_distribution([sample['label'] for sample in remaining_samples], 'Remaining Sample Label Distribution')

    return selected_samples, remaining_samples

In [None]:
# Split the samples listed in output.csv: 80% go to selected_samples, the rest
# to remaining_samples.
# NOTE(review): no random seed is set in the visible cells, so the split will
# differ between runs.
split_csv_file = "/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/output.csv"
selected_samples, remaining_samples = split_and_plot_dataset(split_csv_file, 0.8)

In [None]:
# Write the sample indices of each split, one per line. The paths are relative,
# so the files land in the kernel's current working directory.
save_sample_idx_to_txt(selected_samples, 'train.txt')
save_sample_idx_to_txt(remaining_samples, 'val.txt')