In [1]:
# https://poutyne.org/examples/semantic_segmentation.html
# 
from comet_ml import Experiment
from comet_ml.integration.pytorch import log_model

import os
import pathlib
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchmetrics

# Poutyne Model on GPU
from poutyne import Model

# Custom modules
import utils2
import utils
import datasets
import networks
import callbacks

from dotenv import load_dotenv
# Load environment variables from `.env_comet` in the current working directory;
# init_experiment later reads COMET_API_TOKEN / COMET_PROJECT_NAME / COMET_WORKSPACE via os.getenv.
load_dotenv(os.path.join(pathlib.Path.cwd(), '.env_comet'))




True

In [15]:
# Seed once, before any stochastic work, via the project helper.
# NOTE(review): utils.set_seeds is project code — presumably seeds random/NumPy/torch; confirm.
utils.set_seeds(42)

def init_arguments(argv=None):
    """Build the experiment configuration from command-line style flags.

    The notebook runs inside a Jupyter kernel, whose own ``sys.argv`` must not
    be parsed. By default an empty argument list is therefore used, and every
    option takes its default value.

    Args:
        argv: optional sequence of flag strings (e.g. ``['--num_clients', '5']``)
            that override the defaults. ``None`` (the default) means no overrides.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='test for segmentation')
    parser.add_argument('--model_name', type=str, default='segformer', help='model name (default: segformer)')
    parser.add_argument('--dataset', type=str, default='voc2012', help='dataset (default: voc2012)')
    parser.add_argument('--data_path', type=str, default='~/.data', help='data path (default: ~/.data)')
    parser.add_argument('--batch_size', type=int, default=32, help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=100, help='number of epochs to train (default: 100)')
    parser.add_argument('--learning_rate', type=float, default=0.0005, help='learning rate (default: 0.0005)')
    parser.add_argument('--image_size', type=int, default=224, help='image size (default: 224)')
    parser.add_argument('--out_image_size', type=int, default=56, help='output image size (default: 56)')
    parser.add_argument('--num_classes', type=int, default=22, help='number of classes (default: 22)')
    parser.add_argument('--continue_training', action='store_true', help='continue training (default: False)')
    parser.add_argument('--dirichlet_alpha', type=float, default=1.0, help='dirichlet alpha (default: 1.0)')
    parser.add_argument('--num_clients', type=int, default=10, help='number of clients (default: 10)')
    # An explicit empty list (rather than None) stops argparse from reading the kernel's sys.argv.
    return parser.parse_args([] if argv is None else list(argv))

def init_experiment(dataset, model_name):
    """Create a Comet experiment tagged and named after a dataset/model pair.

    Connection settings come from the COMET_API_TOKEN, COMET_PROJECT_NAME and
    COMET_WORKSPACE environment variables. ``os.getenv`` yields ``None`` for any
    that are unset.

    Args:
        dataset: dataset identifier; added as a tag and used as the name prefix.
        model_name: model identifier; added as a tag and used as the name suffix.

    Returns:
        The configured ``Experiment`` instance.
    """
    comet_settings = {
        "api_key": os.getenv('COMET_API_TOKEN'),
        "project_name": os.getenv('COMET_PROJECT_NAME'),
        "workspace": os.getenv('COMET_WORKSPACE'),
    }
    exp = Experiment(**comet_settings)
    for tag in (dataset, model_name):
        exp.add_tag(tag)
    exp.set_name(f"{dataset}-{model_name}")
    return exp

# Defaults-only config (no CLI in a notebook), then a Comet run named "<dataset>-<model>".
args = init_arguments()
experiment = init_experiment(dataset=args.dataset, model_name=args.model_name)


[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m Comet.ml Experiment Summary
[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m   Data:
[1;38;5;39mCOMET INFO:[0m     display_summary_level : 1
[1;38;5;39mCOMET INFO:[0m     url                   : https://www.comet.com/neighborheo/test-segmentation/2dec92f882d445c392f1f5fca1bc0a46
[1;38;5;39mCOMET INFO:[0m   Others:
[1;38;5;39mCOMET INFO:[0m     Name : voc2012-segformer
[1;38;5;39mCOMET INFO:[0m   Uploads:
[1;38;5;39mCOMET INFO:[0m     environment details      : 1
[1;38;5;39mCOMET INFO:[0m     filename                 : 1
[1;38;5;39mCOMET INFO:[0m     git metadata             : 1
[1;38;5;39mCOMET INFO:[0m     git-patch (uncompressed) : 1 (1.03 KB)
[1;38;5;39mCOMET INFO:[0m     installed packages       : 1
[1;38;5;39mCOMET INFO:[0m 

In [17]:
# Project loader for VOC segmentation; output_size=None is passed through as-is.
train_dataset, valid_dataset = datasets.getVOCSegDatasets(output_size=None)

train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=2,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=2,
)
# No separate test split is loaded here; evaluation reuses the validation loader.
test_loader = valid_loader

# Per-run output directory, e.g. saves/segformer-voc2012-comet_test.
save_path = f'saves/{args.model_name}-{args.dataset}-comet_test'
os.makedirs(save_path, exist_ok=True)

In [18]:
import numpy as np
from PIL import Image


def get_labels(filepath, num_object_classes=20):
    """Return the distinct object-class ids present in a segmentation mask.

    Pixel value 0 and any value above ``num_object_classes`` are ignored.
    NOTE(review): in VOC-style masks this drops the 255 boundary/void value;
    confirm for this dataset.

    Args:
        filepath: path to a mask image readable by PIL.
        num_object_classes: largest pixel value treated as a real class id.
            Defaults to 20, the previously hard-coded threshold.

    Returns:
        list of the distinct class ids (numpy scalars) in ascending order.
    """
    # The context manager closes the file handle deterministically.
    with Image.open(filepath) as img:
        mask = np.array(img)
    class_ids = np.unique(mask)  # np.unique already returns sorted values
    keep = (class_ids != 0) & (class_ids <= num_object_classes)
    return list(class_ids[keep])


# One label set per training image. This assumes the dataset exposes mask paths
# via `.masks`. NOTE: it must run before a later cell rebinds train_dataset to a
# torch Subset, which has no `.masks` attribute.
labels = [get_labels(filepath) for filepath in train_dataset.masks]

In [5]:
def create_label_to_id_map(labels):
    """Assign consecutive integer ids to distinct label collections.

    Collections are compared as sets, so order and duplicates inside one
    collection are ignored. Ids start at 0 and follow first-seen order.

    Args:
        labels: iterable of iterables of hashable labels, one per sample.

    Returns:
        dict mapping ``frozenset(collection)`` to its integer id.
    """
    mapping = {}
    for key in map(frozenset, labels):
        # len(mapping) is evaluated before insertion, so it is the next free id.
        mapping.setdefault(key, len(mapping))
    return mapping

def convert_id_to_label_map(label_to_id):
    """Invert a label -> id mapping into id -> label (the last key wins on duplicate ids)."""
    return dict(zip(label_to_id.values(), label_to_id.keys()))

def convert_labels_to_ids(labels, label_to_id):
    """Map each label collection to its id; raises KeyError for unseen collections."""
    ids = []
    for label in labels:
        ids.append(label_to_id[frozenset(label)])
    return ids

# Each distinct per-image label set becomes one pseudo-class for the Dirichlet split.
label_to_id = create_label_to_id_map(labels)
id_to_label = convert_id_to_label_map(label_to_id)
label_ids = convert_labels_to_ids(labels, label_to_id)

N_class, N_parties = len(label_to_id), args.num_clients
y_data, alpha = label_ids, args.dirichlet_alpha

# Reseed immediately before the stochastic split so the partition is reproducible.
utils2.set_random_seed(42)
dirichlet_count = utils2.get_dirichlet_distribution_count(
    N_class, N_parties, y_data, alpha
)
split_dirichlet_data_index_dict = utils2.get_split_data_index(
    y_data, dirichlet_count
)

party_id: 0, num of samples: 137
party_id: 1, num of samples: 140
party_id: 2, num of samples: 53
party_id: 3, num of samples: 130
party_id: 4, num of samples: 129
party_id: 5, num of samples: 118
party_id: 6, num of samples: 117
party_id: 7, num of samples: 68
party_id: 8, num of samples: 153
party_id: 9, num of samples: 82


In [21]:
import json

# Persist the client split so later runs (and the load cell that follows) reuse
# the exact same partition instead of re-sampling it.
project_dir = pathlib.Path(".").parent.absolute()
split_path = project_dir / "splitfile" / f"{args.dataset}"
split_path.mkdir(parents=True, exist_ok=True)

# NOTE: JSON object keys are always strings. If the split dict uses integer
# client ids, they come back as "0", "1", ... when the file is reloaded.
with open(split_path / f'dirichlet_{args.dirichlet_alpha}_for_{args.num_clients}_clients', "w") as f:
    json.dump(split_dirichlet_data_index_dict, f)

In [22]:
# Reload the persisted split and restrict training to one client's shard.
CLIENT_ID = "2"  # a string, because JSON round-trips dict keys as strings

with open(split_path / f'dirichlet_{args.dirichlet_alpha}_for_{args.num_clients}_clients', "r") as f:
    split_dirichlet_data_index_dict = json.load(f)

# The split indices refer to the full training set. Unwrap any Subset left over
# from a previous run of this cell so that re-running it stays idempotent.
base_train_dataset = (
    train_dataset.dataset
    if isinstance(train_dataset, torch.utils.data.Subset)
    else train_dataset
)
train_dataset = torch.utils.data.Subset(base_train_dataset, split_dirichlet_data_index_dict[CLIENT_ID])

# Rebuild the loader: the one created earlier still wraps the full training set.
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2)
print(len(train_dataset))

53
