### Check that a GPU is available and that the main packages are installed correctly

In [None]:
import torch
# True when PyTorch can see a CUDA device; the value is shown as the cell output.
torch.cuda.is_available()
# torch.cuda.device_count()  # uncomment to see how many CUDA devices are visible

In [None]:
import urllib3
import charset_normalizer
import requests

# Report the versions of the HTTP stack packages, one per line.
print(
    f'urllib3 version: {urllib3.__version__}',
    f'charset_normalizer version: {charset_normalizer.__version__}',
    f'requests version: {requests.__version__}',
    sep='\n',
)

In [None]:
import mmcv
# Print the installed mmcv version.
print(mmcv.__version__)

In [None]:
from mmpretrain.models.heads import MultiLabelClsHead
from mmpretrain.models.builder import BACKBONES, HEADS

In [None]:
import mmengine
# NOTE(review): this import was left incomplete (a SyntaxError that stopped the
# cell from running); `Runner` is assumed to be the intended name — confirm.
from mmengine.runner import Runner

### Convert the CSV file into a dictionary in MMLab's standard data format

In [None]:
import csv
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random


def csv_to_dict(src_dir, in_csv_path, out_json_path,
                exclude_columns=("", "center_lon", "center_lat", "chip_id")):
    """Convert an annotation CSV into an MMLab-style annotation JSON file.

    The JSON written has the layout
    ``{"metainfo": {"classes": [...]}, "data_list": [...]}``. Every CSV column
    whose header is not in ``exclude_columns`` is treated as a class. Each CSV
    row becomes one ``data_list`` entry with:

    * ``img_path``: ``f"{src_dir}/chip_{chip_id}.tif"``
    * ``gt_label``: indices (into ``classes``) of the class columns whose
      value equals 1.

    Args:
        src_dir (str): Directory holding the chip images; only used to build
            the image paths.
        in_csv_path (str): Input CSV with a header row and a ``chip_id`` column.
        out_json_path (str): Path where the JSON annotation file is written.
        exclude_columns (Iterable[str]): Header names that are not classes.
            Defaults to the index column ``""``, ``center_lon``,
            ``center_lat`` and ``chip_id``.

    Raises:
        ValueError: If the CSV has no header row or no ``chip_id`` column, or
            if a class cell is not numeric.
    """
    excluded = set(exclude_columns)

    # newline='' is what the csv module expects for files it reads.
    # utf-8-sig strips a leading BOM, which would otherwise be glued onto the
    # first header name and turn that column into a bogus class.
    with open(in_csv_path, 'r', newline='', encoding='utf-8-sig') as csvfile:
        reader = csv.DictReader(csvfile)

        if reader.fieldnames is None:
            raise ValueError(f"CSV file '{in_csv_path}' is empty (no header row)")
        if "chip_id" not in reader.fieldnames:
            raise ValueError(f"CSV file '{in_csv_path}' has no 'chip_id' column")

        classes = [field for field in reader.fieldnames if field not in excluded]

        data_list = []
        for row in reader:
            data_list.append({
                "img_path": f"{src_dir}/chip_{row['chip_id']}.tif",
                # float() also accepts values such as "1.0", e.g. when the CSV
                # was written from a float column.
                "gt_label": [i for i, class_name in enumerate(classes)
                             if float(row[class_name]) == 1],
            })

    result_dict = {"metainfo": {"classes": classes}, "data_list": data_list}

    with open(out_json_path, "w", encoding="utf-8") as jsonfile:
        json.dump(result_dict, jsonfile, indent=4)


In [None]:
# Convert the project's annotation CSV into the MMLab-style JSON annotation file.
src_dir = "/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/chips"
in_csv_path = "/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/annotations/CEO_plot_id.csv"
out_json_path = "/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/annotations/dataset_dict.json"
csv_to_dict(
    src_dir=src_dir,
    in_csv_path=in_csv_path,
    out_json_path=out_json_path,
)

In [None]:
class CSVConverter:
    """Parse an annotation CSV into an MMLab-style annotation dict held in memory.

    Columns are selected by position: ``header[3:-1]`` are the class names and
    the last column holds the chip id used to build ``img_path``. The first
    three columns are ignored.
    NOTE(review): those three are presumably the index, ``center_lon`` and
    ``center_lat`` columns — confirm against the CSV layout.
    """

    def __init__(self, src_dir, in_csv_path):
        """Store the inputs and parse the CSV immediately.

        Args:
            src_dir (str): Directory holding the chip images; only used to
                build the image paths.
            in_csv_path (str): Path of the annotation CSV.
        """
        self.src_dir = src_dir
        self.in_csv_path = in_csv_path
        self.data = self._parse_csv()

    def _parse_csv(self):
        """Read the CSV into ``{"metainfo": {"classes": [...]}, "data_list": [...]}``.

        Each data row yields ``img_path = f"{src_dir}/chip_{row[-1]}.tif"``.
        It also yields ``gt_label``, the indices of the class columns whose
        value equals 1.

        Raises:
            ValueError: If the file has no header row, if a row's column count
                differs from the header's, or if a class cell is not numeric.
        """
        # newline='' is what the csv module expects; utf-8-sig drops a leading BOM.
        with open(self.in_csv_path, "r", newline="", encoding="utf-8-sig") as f:
            # csv.reader yields [] for blank lines (e.g. a trailing empty line).
            # Skip them instead of crashing on row[-1].
            rows = [row for row in csv.reader(f) if row]

        if not rows:
            raise ValueError(f"CSV file '{self.in_csv_path}' is empty (no header row)")

        header, *records = rows
        classes = header[3:-1]

        data_list = []
        for row in records:
            # With positional selection, a short or long row would silently
            # read the chip id from the wrong column.
            if len(row) != len(header):
                raise ValueError(
                    f"Row {row!r} has {len(row)} columns, expected {len(header)}"
                )
            data_list.append({
                "img_path": f"{self.src_dir}/chip_{row[-1]}.tif",
                # float() also accepts values such as "1.0".
                "gt_label": [i for i, value in enumerate(row[3:-1]) if float(value) == 1],
            })

        return {
            "metainfo": {
                "classes": classes
            },
            "data_list": data_list
        }

    def get_data(self):
        """Return the parsed annotation dict (the same object on every call)."""
        return self.data

In [None]:
# Inputs for the class-based converter (same files as the function-based conversion).
in_csv_path = "/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/annotations/CEO_plot_id.csv"
src_dir = "/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/chips"

In [None]:
# Parse the CSV in memory and dump the resulting annotation dict.
converter = CSVConverter(src_dir=src_dir, in_csv_path=in_csv_path)
print(converter.get_data())

### Read the dataset using the `geospatial_fm` package

In [None]:
from geospatial_fm import MultiLabelGeospatialDataset
from geospatial_fm import LoadGeospatialImageFromFile

In [None]:
# Split files for the geospatial_fm dataset. val_split is defined here but only
# the training split is loaded in this cell.
train_split = "/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data_splits/multi_label_classification/train.txt"
val_split = "/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data_splits/multi_label_classification/val.txt"
ann_file = "/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/annotations/dataset_dict.json"
train_dataset = MultiLabelGeospatialDataset(ann_file, train_split)

In [None]:
# Number of samples in the training split (displayed as the cell output).
len(train_dataset)

### Custom Dataset

In [None]:
from mmpretrain.registry import DATASETS
from mmpretrain.datasets.base_dataset import BaseDataset
from typing import List, Any, Union, Optional

#@DATASETS.register_module()
# NOTE(review): registration is left disabled. mmpretrain appears to ship its own
# dataset named `MultiLabelDataset`, so registering this class under the same
# name would likely be rejected by the registry; rename it before enabling.
class MultiLabelDataset(BaseDataset):
    """Multi-label Dataset for image classification.

    This dataset extends BaseDataset to support multi-label classification:
    each sample's ``gt_label`` is a list of class indices.

    Args:
        ann_file (str): Annotation file path; must end with ``.json``.
        metainfo (dict, optional): Meta information for dataset, such as class information.
        **kwargs: Other arguments, forwarded unchanged to ``BaseDataset``.

    Raises:
        ValueError: If ``ann_file`` does not end with ``.json``.
    """

    def __init__(self,
                 ann_file: str,
                 metainfo: Optional[dict] = None,
                 **kwargs: Any):
        # Only JSON annotation files are accepted by this dataset.
        # str() also lets path-like objects through the check.
        if not str(ann_file).endswith('.json'):
            raise ValueError("Annotation file must be a .json file")

        # Call the parent class's init method
        super().__init__(ann_file=ann_file, metainfo=metainfo, **kwargs)

    # typing.List keeps the annotation valid on Python < 3.9, where the
    # built-in `list[int]` raises TypeError when the class is created.
    def get_cat_ids(self, idx: int) -> List[int]:
        """Get category ids by index.

        Args:
            idx (int): Index of data.

        Returns:
            List[int]: Image categories (the ``gt_label`` list) of the specified index.

        Raises:
            KeyError: If the sample's data info has no ``gt_label`` entry.
        """
        data_info = self.get_data_info(idx)
        if 'gt_label' not in data_info:
            raise KeyError(f"'gt_label' not found in data_info for index {idx}")

        return data_info['gt_label']

In [None]:
# Build the full (unsplit) dataset from the JSON annotation file written earlier.
annotation_file = "/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/annotations/dataset_dict.json"
complete_dataset = MultiLabelDataset(ann_file=annotation_file)

In [None]:
# Inspect the last sample (index -1) of the full dataset.
complete_dataset[-1]

### Get the split indices for the training and validation sets

#### CSV output

In [None]:
# Create a csv out of a dataset (one row per sample).
def save_to_csv(dataset, output_csv_path):
    """Write one CSV row per sample: chip file name, label list and sample index.

    Args:
        dataset: Iterable of dict-like samples that each provide ``img_path``,
            ``gt_label`` and ``sample_idx``.
        output_csv_path (str): Destination file; overwritten if it exists.

    The ``label`` column holds ``str(gt_label)``, e.g. ``"[0, 3]"``.
    """
    with open(output_csv_path, mode='w', newline='') as handle:
        writer = csv.writer(handle)
        writer.writerow(['chip_name', 'label', 'sample_idx'])

        for item in dataset:
            # Keep only the file name: the text after the last '/'.
            chip_name = item['img_path'].rpartition('/')[2]
            idx = item['sample_idx']
            writer.writerow([chip_name, item['gt_label'], idx])

In [None]:
save_to_csv(dataset=complete_dataset, output_csv_path="/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/output.csv")

#### txt output

In [None]:
def plot_label_distribution(labels, title):
    """Show a bar chart of how often each label index occurs.

    Args:
        labels: Iterable of per-sample label lists, e.g. ``[[0, 3], [1]]``.
        title (str): Plot title.

    When no label occurs at all (no samples, or only empty label lists), a
    notice is printed and nothing is plotted. Unpacking an empty ``zip`` would
    otherwise raise ``ValueError``.
    """
    from collections import Counter

    label_freq = Counter(label for label_list in labels for label in label_list)
    if not label_freq:
        print(f"No labels to plot for '{title}'.")
        return

    # Sort by label index so bars and ticks appear in ascending order.
    label_ids, freqs = zip(*sorted(label_freq.items()))
    plt.bar(label_ids, freqs)
    plt.xlabel('Label Index')
    plt.ylabel('Frequency')
    plt.title(title)
    plt.xticks(label_ids)
    plt.show()


def save_sample_idx_to_txt(samples, output_txt_path):
    """Write each sample's ``sample_idx`` on its own line, overwriting the file.

    Args:
        samples: Iterable of dicts that each contain a ``sample_idx`` key.
        output_txt_path (str): Destination text file.
    """
    with open(output_txt_path, 'w') as out:
        out.writelines(str(sample['sample_idx']) + '\n' for sample in samples)


def split_and_plot_dataset(csv_path, split_ratio, seed=None):
    """Randomly split the samples of a CSV in two and plot each part's label distribution.

    Args:
        csv_path (str): CSV with at least a ``label`` column holding a Python
            list literal such as ``"[0, 3]"``.
        split_ratio (float): Fraction in [0, 1] of the shuffled samples placed
            in the first (selected) part.
        seed (int, optional): Seed for a private RNG so that the split is
            reproducible. ``None`` (default) shuffles with the global
            ``random`` state.

    Returns:
        tuple[list[dict], list[dict]]: ``(selected_samples, remaining_samples)``.
        Each sample is one CSV row as a dict, with ``label`` parsed into a list.

    Raises:
        ValueError: If ``split_ratio`` is outside [0, 1].
        ValueError or SyntaxError: If a ``label`` cell is not a Python literal.
    """
    import ast

    if not 0.0 <= split_ratio <= 1.0:
        raise ValueError(f"split_ratio must be in [0, 1], got {split_ratio}")

    # Read CSV
    df = pd.read_csv(csv_path)
    # literal_eval only parses Python literals; eval() would execute any code
    # found in the file.
    df['label'] = df['label'].apply(ast.literal_eval)
    all_samples = df.to_dict('records')

    # Shuffle and split
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(all_samples)
    split_index = int(len(all_samples) * split_ratio)
    selected_samples = all_samples[:split_index]
    remaining_samples = all_samples[split_index:]

    # Plot label distribution
    plot_label_distribution([sample['label'] for sample in selected_samples], 'Selected Sample Label Distribution')
    plot_label_distribution([sample['label'] for sample in remaining_samples], 'Remaining Sample Label Distribution')

    return selected_samples, remaining_samples

In [None]:
split_csv_file = "/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/output.csv"
# 80% of the shuffled samples go to the first (selected) part, the rest to the second.
selected_samples, remaining_samples = split_and_plot_dataset(split_csv_file, split_ratio=0.8)

In [None]:
# NOTE(review): these files land in the current working directory, whereas the
# geospatial dataset cell reads its splits from
# data_splits/multi_label_classification/ — confirm the intended location.
save_sample_idx_to_txt(samples=selected_samples, output_txt_path='train.txt')
save_sample_idx_to_txt(samples=remaining_samples, output_txt_path='val.txt')

### Calculate mean and standard deviation per band over the whole dataset

In [None]:
import os
import glob
import xarray as xr
import numpy as np
import pandas as pd

def calculate_mean_std_and_save_as_csv(src_dir, file_pattern='chip*.tif', output_csv='result.csv'):
    """Compute the per-band mean and standard deviation over every pixel of every chip.

    Each file matching ``file_pattern`` inside ``src_dir`` is opened on its own.
    Per-band running sums (sum, sum of squares, valid-pixel count) are
    accumulated across files, so the statistics are global over the whole
    dataset rather than per chip. NaN pixels are ignored, and the standard
    deviation is the population value (ddof=0).

    NOTE(review): reading GeoTIFFs through xarray with ``engine="rasterio"``
    requires the rioxarray backend to be installed. Every chip is assumed to
    have the same bands in the same order; only the band count is checked.

    Args:
        src_dir (str): Directory containing the image chips.
        file_pattern (str): Glob pattern, relative to ``src_dir``, that selects the chips.
        output_csv (str): Path of the CSV written with columns Band, Mean, StdDev.

    Returns:
        pandas.DataFrame | None: The statistics table, or ``None`` when no file matches.

    Raises:
        ValueError: If a chip's band count differs from the first chip's.
    """
    # Sorted so that the traversal (and float summation) order is deterministic.
    file_paths = sorted(glob.glob(os.path.join(src_dir, file_pattern)))

    if not file_paths:
        print(f"No TIFF files found in the directory '{src_dir}' with the pattern '{file_pattern}'")
        return None

    band_labels = None
    band_sum = band_sq_sum = band_count = None

    for path in file_paths:
        # Opening chips one at a time keeps memory bounded and avoids any
        # coordinate-based merging across chips.
        with xr.open_dataarray(path, engine="rasterio") as chip:
            # Put the band axis first so each row of `pixels` is one band.
            ordered = chip.transpose("band", ...)
            if band_labels is None:
                band_labels = ordered["band"].values
                n_bands = len(band_labels)
                band_sum = np.zeros(n_bands)
                band_sq_sum = np.zeros(n_bands)
                band_count = np.zeros(n_bands, dtype=np.int64)
            pixels = np.asarray(ordered.values, dtype=np.float64).reshape(ordered.sizes["band"], -1)

        if pixels.shape[0] != len(band_labels):
            raise ValueError(f"'{path}' has {pixels.shape[0]} bands, expected {len(band_labels)}")

        valid = ~np.isnan(pixels)
        filled = np.where(valid, pixels, 0.0)
        band_sum += filled.sum(axis=1)
        band_sq_sum += np.square(filled).sum(axis=1)
        band_count += valid.sum(axis=1)

    # A band with no valid pixel gets NaN statistics instead of a division warning.
    with np.errstate(invalid="ignore", divide="ignore"):
        band_means = band_sum / band_count
        band_vars = band_sq_sum / band_count - np.square(band_means)
    # Rounding can push a near-zero variance slightly below 0; clip before sqrt.
    band_stds = np.sqrt(np.clip(band_vars, 0.0, None))

    # Create a DataFrame to store the results
    result_df = pd.DataFrame({'Band': band_labels, 'Mean': band_means, 'StdDev': band_stds})

    # Print the results to the screen
    print("Mean and Standard Deviation for Each Band:")
    print(result_df)

    # Save the results as a CSV file
    result_df.to_csv(output_csv, index=False)
    print(f"Results saved to '{output_csv}'")
    return result_df


In [None]:
# Example usage: per-band statistics over every chip, written to output_csv.
directory = "/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/chips"
output_csv = '/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/result.csv'
# Pass the path defined above. A literal 'result.csv' would write to the current
# working directory instead.
calculate_mean_std_and_save_as_csv(directory, output_csv=output_csv)

#### Calculate global class weights

In [None]:
import json
import numpy as np

def calculate_global_class_weights_from_json(json_file_path, ignore_index=0):
    """
    Calculate class weights based on the global class distribution from a JSON file.

    Counts how often each class index appears in the ``gt_label`` lists of all
    samples. It then derives inverse-frequency weights for every class except
    ``ignore_index``:

        ratio_c = count_c / sum(counts)
        weight_c = (1 / ratio_c) / sum_k(1 / ratio_k)

    A class that never occurs would make ``1 / ratio`` infinite and turn every
    weight into NaN. Such classes get weight 0 instead, and the remaining
    weights are normalised among the classes that do occur.

    Parameters:
    json_file_path (str): Path to the JSON file with ``metainfo.classes`` and a
        ``data_list`` whose samples each contain a ``gt_label`` list.
    ignore_index (int, optional): Class index excluded from the computation.
        A value that matches no class (e.g. -100) excludes nothing.
        Defaults to 0.

    Returns:
    numpy array: Weights for the non-ignored classes, in class order, summing
        to 1. Its length is ``num_classes - 1`` when ``ignore_index`` is a
        valid class index.

    Raises:
    ValueError: If a label is outside ``[0, num_classes)`` or no non-ignored
        class occurs at all.
    """
    # Load the JSON file
    with open(json_file_path, 'r', encoding='utf-8') as json_file:
        data = json.load(json_file)

    num_classes = len(data["metainfo"]["classes"])

    # Flatten every sample's labels into a single integer array.
    all_labels = np.fromiter(
        (label for sample in data["data_list"] for label in sample["gt_label"]),
        dtype=np.int64,
    )
    # Negative labels would otherwise wrap around to the last classes silently.
    if all_labels.size and (all_labels.min() < 0 or all_labels.max() >= num_classes):
        raise ValueError(
            f"gt_label values must lie in [0, {num_classes}), "
            f"found range [{all_labels.min()}, {all_labels.max()}]"
        )
    global_counts = np.bincount(all_labels, minlength=num_classes)

    # Ignore the class specified by ignore_index when calculating the ratio and weight
    valid_indices = np.arange(num_classes) != ignore_index
    valid_counts = global_counts[valid_indices]
    total = np.sum(valid_counts)
    if total == 0:
        raise ValueError("No labels of non-ignored classes found; cannot compute weights")

    ratio = valid_counts.astype(float) / total
    inverse = np.zeros_like(ratio)
    present = ratio > 0
    inverse[present] = 1. / ratio[present]
    weights = inverse / np.sum(inverse)

    return weights

In [None]:
json_file_path = "/mnt/c/My_documents/summer_project/task2_gfm/hls-foundation-os/data/annotations/dataset_dict.json"
# -100 matches no class index, so every class receives a weight.
ignore_index = -100

class_weights = calculate_global_class_weights_from_json(json_file_path, ignore_index=ignore_index)
print(class_weights)


In [None]:
# Launch training with OpenMMLab's `mim` CLI (IPython shell escape) using the
# config at configs/multi_label_classification.py.
# NOTE(review): confirm that `--launcher pytorch` (a distributed launcher) is
# intended here and whether a GPU-count option is needed alongside it.
! mim train mmpretrain --launcher pytorch configs/multi_label_classification.py