In [None]:
# https://poutyne.org/examples/semantic_segmentation.html
# 
from comet_ml import Experiment
from comet_ml.integration.pytorch import log_model

import sys
# Extend the module search path; presumably this is where the project-local
# `utils` / `datasets` modules imported below live — TODO confirm.
# NOTE(review): this is a hardcoded, machine-specific absolute path, so the
# notebook will not resolve those imports on another machine as-is.
sys.path.append("/home/suncheol/code/FedTest/0_FedMHAD_Seg")
import os
import pathlib
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchmetrics

# Poutyne Model on GPU
from poutyne import Model

# # Custom modules
# import utils2
import utils
import datasets
# import networks
# import callbacks

from dotenv import load_dotenv
# Load environment variables from the `.env_comet` file in the parent of the
# current working directory. `init_experiment` reads its COMET_* settings with
# os.getenv, so presumably they are defined in that file — TODO confirm.
load_dotenv(os.path.join(pathlib.Path(".").absolute().parent, '.env_comet'))


In [None]:
# Fix the random seed through the project helper before anything else runs.
# NOTE(review): which RNGs `utils.set_random_seed` covers is not visible here — confirm.
utils.set_random_seed(42)

def init_arguments():
    """Build the experiment configuration from the argparse defaults.

    The parser is handed an explicit empty argument list, so ``sys.argv`` is
    never consulted and every option keeps its default. To change the
    configuration, edit the defaults below.

    Returns:
        argparse.Namespace: the options (``model_name``, ``dataset``,
        ``data_path``, ``batch_size``, ``epochs``, ``learning_rate``,
        ``image_size``, ``out_image_size``, ``num_classes``,
        ``continue_training``, ``dirichlet_alpha``, ``num_clients``,
        ``malicious``).
    """
    parser = argparse.ArgumentParser(description='test for segmentation')
    # Help texts use argparse's ``%(default)s`` placeholder so the documented
    # default can never disagree with the real one (several hand-written
    # defaults had drifted from the actual values).
    parser.add_argument('--model_name', type=str, default='segformer', help='model name (default: %(default)s)')
    parser.add_argument('--dataset', type=str, default='voc2012', help='dataset (default: %(default)s)')
    parser.add_argument('--data_path', type=str, default='~/.data', help='data path (default: %(default)s)')
    parser.add_argument('--batch_size', type=int, default=32, help='input batch size for training (default: %(default)s)')
    parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train (default: %(default)s)')
    parser.add_argument('--learning_rate', type=float, default=0.0005, help='learning rate (default: %(default)s)')
    parser.add_argument('--image_size', type=int, default=224, help='image size (default: %(default)s)')
    parser.add_argument('--out_image_size', type=int, default=56, help='output image size (default: %(default)s)')
    parser.add_argument('--num_classes', type=int, default=3, help='number of classes (default: %(default)s)')
    parser.add_argument('--continue_training', action='store_true', help='continue training (default: %(default)s)')
    parser.add_argument('--dirichlet_alpha', type=float, default=0.1, help='dirichlet alpha (default: %(default)s)')
    parser.add_argument('--num_clients', type=int, default=3, help='number of clients (default: %(default)s)')
    parser.add_argument('--malicious', type=int, default=0, help='number of malicious clients (default: %(default)s)')
    # An explicit empty list (rather than an empty string) states the intent:
    # parse no command-line arguments at all.
    return parser.parse_args([])

def init_experiment(dataset, model_name):
    """Create a Comet experiment tagged and named after the dataset and model.

    The API key, project name and workspace are read from the
    ``COMET_API_TOKEN``, ``COMET_PROJECT_NAME`` and ``COMET_WORKSPACE``
    environment variables.

    Args:
        dataset: dataset identifier; added as a tag and used in the run name.
        model_name: model identifier; added as a tag and used in the run name.

    Returns:
        The configured ``Experiment`` instance.
    """
    comet_settings = {
        'api_key': os.getenv('COMET_API_TOKEN'),
        'project_name': os.getenv('COMET_PROJECT_NAME'),
        'workspace': os.getenv('COMET_WORKSPACE'),
    }
    run = Experiment(**comet_settings)
    # Tag order matches the run name: dataset first, then model.
    for tag in (dataset, model_name):
        run.add_tag(tag)
    run.set_name(f"{dataset}-{model_name}")
    return run

# Build the notebook configuration and start the Comet experiment for this run
# (tagged and named from the dataset and model name).
args = init_arguments()
experiment = init_experiment(args.dataset, args.model_name)


In [None]:
# NOTE(review): `datasets` is already imported in the first cell; this import is redundant.
import datasets
# Build the Pascal VOC segmentation partition helper from the notebook arguments.
datasetpartition = datasets.PascalVocSegmentationPartition(args)
# NOTE(review): partition id -1 presumably selects the full, unpartitioned data — confirm in `datasets`.
train_dataset, valid_dataset = datasetpartition.load_partition(-1)
# Training batches are shuffled, validation batches are not; num_workers=0 loads data in the main process.
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0 )
valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0 )
# No separate test split is loaded here: the test loader is an alias of the validation loader.
test_loader = valid_loader

# Create the save directory for this model/dataset pair (no error if it already exists).
save_path = f'saves/{args.model_name}-{args.dataset}-comet_test'
os.makedirs(save_path, exist_ok=True)

In [None]:
# Pascal VOC class names: 'background' followed by the 20 object categories (21 names in total).
# NOTE(review): presumably a name's position in this list equals its pixel value in the
# segmentation masks — confirm against the dataset.
class_names = ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
                'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
                'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
 

In [None]:
import numpy as np
from PIL import Image
def get_labels(filepath, max_class_id=20):
    """Return the foreground class ids present in a segmentation mask file.

    Args:
        filepath: path of the mask image; it is opened with PIL and converted
            to a NumPy array.
        max_class_id: largest pixel value accepted as a class id. Pixel values
            above it are treated as background (0). Defaults to 20, the value
            that was previously hard-coded.

    Returns:
        list: the distinct non-zero pixel values in ascending order; empty
        when the mask contains background only.
    """
    # The context manager releases the file handle as soon as the pixel data
    # has been copied into the array, instead of leaving that to garbage
    # collection while many mask files are scanned in a row.
    with Image.open(filepath) as mask_image:
        arr = np.array(mask_image)
    # Fold out-of-range pixel values into background.
    arr[arr > max_class_id] = 0
    # np.unique returns its result already sorted, so no extra sort is needed.
    unique_list = np.unique(arr)
    return list(unique_list[unique_list != 0])

# Foreground class ids of every training mask: one list per mask file, in the
# order of `train_dataset.masks`.
labels = [get_labels(filepath) for filepath in train_dataset.masks]

In [None]:
# Number of training masks that were scanned.
len(labels)

In [None]:
# Same per-mask class-id extraction, applied to the validation masks.
test_labels = list(map(get_labels, valid_dataset.masks))
# Positions of the validation masks that contain at least one foreground class.
test_not_empty_indices = [
    position
    for position, found in enumerate(test_labels)
    if len(found) > 0
]

In [None]:
# Show the number of non-empty validation masks and a short preview of their
# positions instead of dumping the entire index list into the cell output.
len(test_not_empty_indices), test_not_empty_indices[:10]

In [None]:
def create_label_to_id_map(labels):
    """Assign a consecutive integer id to every distinct label combination.

    Each entry of ``labels`` is treated as an unordered set (via
    ``frozenset``), so entries with the same members share one id. Ids start
    at 0 and follow the order in which a combination is first seen.

    Args:
        labels: iterable of iterables of hashable labels.

    Returns:
        dict: mapping ``frozenset`` of labels -> integer id.
    """
    combo_ids = {}
    for members in labels:
        # len(combo_ids) is the next unused id; setdefault leaves an already
        # known combination untouched.
        combo_ids.setdefault(frozenset(members), len(combo_ids))
    return combo_ids

def convert_id_to_label_map(label_to_id):
    """Invert a label -> id mapping into an id -> label mapping.

    Args:
        label_to_id: dict whose values are hashable ids.

    Returns:
        dict: each id mapped back to its key. If several keys share an id,
        the key that comes last in iteration order is kept.
    """
    return dict(zip(label_to_id.values(), label_to_id.keys()))

def convert_labels_to_ids(labels, label_to_id):
    """Translate each label combination into its integer id.

    Args:
        labels: iterable of iterables of hashable labels.
        label_to_id: dict keyed by ``frozenset`` of labels.

    Returns:
        list: one id per entry of ``labels``, in the same order.

    Raises:
        KeyError: if a combination is not a key of ``label_to_id``.
    """
    ids = []
    for members in labels:
        ids.append(label_to_id[frozenset(members)])
    return ids

# Give every distinct combination of classes found in the training masks one integer id.
label_to_id = create_label_to_id_map(labels)
id_to_label = convert_id_to_label_map(label_to_id)

# One combination id per training mask, aligned with `labels`.
label_ids = convert_labels_to_ids(labels, label_to_id)

# Inputs for the split: the "classes" here are label combinations, not the
# individual VOC categories.
N_class = len(label_to_id)
N_parties = args.num_clients
y_data = label_ids
# NOTE(review): both `utils` helpers are defined outside this notebook. Presumably the
# first draws per-party sample counts for each class from a Dirichlet distribution with
# concentration `args.dirichlet_alpha`, and the second turns those counts into per-party
# lists of sample indices — confirm against `utils`.
dirichlet_count = utils.get_dirichlet_distribution_count(N_class, N_parties, y_data, args.dirichlet_alpha)
split_dirichlet_data_index_dict = utils.get_split_data_index(y_data, dirichlet_count)

In [None]:
# For every party keep only the samples whose mask contains at least one
# foreground class.
indices = {}
for i in range(N_parties):
    index_list = split_dirichlet_data_index_dict[i]
    # Cast to built-in int so the mapping stays JSON-serializable even if the
    # split helper returns NumPy integers, which the json module cannot encode.
    index = [int(k) for k in index_list if len(labels[k]) != 0]
    indices[i] = index
    # Report the party by its dictionary key (the identifier under which the
    # split is stored) rather than by a 1-based running count.
    print(f"party {i}, index : {len(index)}")

In [None]:
# Split files live under <parent of the working directory>/splitfile/<dataset>_<num_clients>_clients/.
project_dir = pathlib.Path(".").absolute().parent
split_path = project_dir / "splitfile" / f"{args.dataset}_{args.num_clients}_clients" 
split_path.mkdir(parents=True, exist_ok=True)

# Save the per-party index lists as JSON.
# NOTE: JSON object keys are always strings, so the integer party ids used as
# keys of `indices` are written as "0", "1", ... in the file.
import json
with open(split_path / f'dirichlet_{args.dirichlet_alpha}_for_{args.num_clients}_clients', "w") as f:
    json.dump(indices, f)

In [None]:
# Load the split file back from disk.
# NOTE(review): this rebinds `split_dirichlet_data_index_dict` to the JSON version,
# whose keys are strings ("0", "1", ...) rather than the integer party ids used
# when the split was built.
with open(split_path / f'dirichlet_{args.dirichlet_alpha}_for_{args.num_clients}_clients', "r") as f:
    split_dirichlet_data_index_dict = json.load(f)
    
# Restrict the training data to the samples assigned to party "2" (hard-coded).
# NOTE(review): `train_dataset` is overwritten with a subset of itself, so running
# this cell a second time would apply the indices to the already reduced dataset;
# re-run from the dataset-loading cell first.
train_dataset = torch.utils.data.Subset(train_dataset, split_dirichlet_data_index_dict["2"])
print(len(train_dataset))