# In [None]:
import argparse
import os
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from torch.utils.data import DataLoader
import albumentations as A
from transformers import MaskFormerImageProcessor, Mask2FormerForUniversalSegmentation
import evaluate
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from custom_datasets import ImageSegmentationDataset
from utils import color_palette
import copy

# def parse_args():
#     parser = argparse.ArgumentParser(description="Train MaskFormer model for instance segmentation")
#     parser.add_argument("--train_csv", type=str, default='./outputs', help="Path to .csv file with rows for all train")
#     parser.add_argument("--val_csv", type=str, default='./outputs', help="Path to .csv file with rows for all val")
#     parser.add_argument("--csv_img_path_col", type=str, default='image', help="Column name in the csv for the path to the image")
#     parser.add_argument("--csv_label_path_col", type=str, default='label', help="Column name in the csv for the path to the segmentation label")
#     parser.add_argument("--output_directory", type=str, default='./outputs', help="Desired path for output files (model, val inferences, etc)")
#     parser.add_argument('--dataset_mean', nargs='+', type=float, help='Array of float values for mean i.e. 0.709 0.439 0.287')
#     parser.add_argument('--dataset_std', nargs='+', type=float, help='Array of float values for std i.e. 0.210 0.220 0.199')
#     parser.add_argument("--lr", type=float, default=0.00003, help="Learning rate for the optimizer")
#     parser.add_argument("--batch_size", type=int, default=16, help="Batch size for training and testing")
#     parser.add_argument('--jitters', nargs='+', type=float, help='Array of float jitter values: brightness, contrast, saturation, hue, probability')
#     parser.add_argument("--num_epochs", type=int, default=50, help="Max number of epochs to train")
#     parser.add_argument("--patience", type=int, default=5, help="Early stopping")
#     parser.add_argument("--num_val_outputs_to_save", type=int, default=3, help="Number of examples from val to save, so you can see your model improve on it during training.")
#     parser.add_argument("--num_workers", type=int, default=0, help="Number of workers for dataloaders")
#     return parser.parse_args()

# One directory per federated-learning client; all of them live under the
# same preprocessed-data root.
_FEDERATED_ROOT = "/sddata/data/retina_datasets_preprocessed/federated_learning_public"
dataset_dirs = [
    f"{_FEDERATED_ROOT}/{client_name}"
    for client_name in ("binrushed", "drishti", "magrabi")
]

def load_dataset(dataset_dir, transform):
    """Build an ImageSegmentationDataset from one client's dataset directory.

    The files found in ``<dataset_dir>/images`` and ``<dataset_dir>/labels``
    are paired positionally, so both listings are sorted first:
    ``os.listdir`` returns entries in arbitrary order, and unsorted listings
    could pair an image with another image's label.

    Args:
        dataset_dir: Directory containing ``images`` and ``labels``
            sub-directories.
        transform: Forwarded unchanged as the ``transform`` keyword of
            ``ImageSegmentationDataset``.

    Returns:
        The ``ImageSegmentationDataset`` built from the paired paths.

    Raises:
        ValueError: If the two sub-directories hold a different number of
            entries, since a positional pairing is then impossible.
    """
    image_dir = os.path.join(dataset_dir, "images")
    label_dir = os.path.join(dataset_dir, "labels")

    # NOTE(review): pairing by sorted order assumes an image and its label
    # have file names that sort to the same position -- confirm for each dataset.
    image_paths = [os.path.join(image_dir, name) for name in sorted(os.listdir(image_dir))]
    label_paths = [os.path.join(label_dir, name) for name in sorted(os.listdir(label_dir))]

    if len(image_paths) != len(label_paths):
        raise ValueError(
            f"Found {len(image_paths)} images in {image_dir} but "
            f"{len(label_paths)} labels in {label_dir}"
        )
    return ImageSegmentationDataset(image_paths, label_paths, transform=transform)

def FedAvg(weights):
    """Average a list of state dicts element-wise (federated averaging).

    Every client contributes equally (unweighted mean). The inputs are not
    modified: the result starts as a deep copy of the first state dict, which
    also keeps its mapping type and any attributes attached to it.

    Args:
        weights: Non-empty sequence of state dicts that share the same keys
            and tensor shapes.

    Returns:
        A new state dict whose tensors are the per-key mean of the inputs.
        Each tensor keeps the dtype of the corresponding input tensor;
        non-floating tensors are averaged in float64 and cast back, which
        truncates any fractional part.

    Raises:
        ValueError: If ``weights`` is empty.
        KeyError: If a later state dict lacks a key present in the first one.
    """
    if not weights:
        raise ValueError("FedAvg requires at least one state dict to average")

    num_clients = len(weights)
    global_model = copy.deepcopy(weights[0])

    with torch.no_grad():
        for key, first in global_model.items():
            if first.is_floating_point() or first.is_complex():
                # `first` belongs to the deep copy, so in-place accumulation
                # never touches the callers' tensors.
                for client in weights[1:]:
                    first += client[key]
                first /= num_clients
            else:
                # Integer/bool entries: true division would silently turn
                # them into floats, so average in float64 and cast back.
                total = first.to(torch.float64)
                for client in weights[1:]:
                    total += client[key]
                global_model[key] = (total / num_clients).to(first.dtype)
    return global_model

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``dataloader`` and return the mean loss.

    Args:
        model: Module to train; it is switched to training mode.
        dataloader: Iterable yielding ``(images, labels)`` batches. It does
            not need to support ``len()``; batches are counted as they are
            consumed.
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Callable invoked as ``criterion(outputs, labels)`` that
            returns a scalar loss tensor.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        The arithmetic mean of the per-batch loss values as a float.

    Raises:
        ValueError: If ``dataloader`` yields no batches, because no mean
            loss can be computed.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches; cannot compute a mean loss")
    return total_loss / num_batches

def main():
    """Run federated training of Mask2Former over the directories in ``dataset_dirs``.

    One model, dataloader and optimizer are created per dataset directory.
    In every epoch each model is trained for one pass over its own data, the
    resulting state dicts are averaged with ``FedAvg``, and the average is
    loaded back into every model. The final averaged weights are saved as
    ``global_model.pth``.
    """
    # args = parse_args()
    # ipynb values
    train_csv = '/home/thakuriu/fl_glaucoma_seg/csvs/binrushed_train.csv'
    val_csv = '/home/thakuriu/fl_glaucoma_seg/csvs/binrushed_val.csv'
    csv_img_path_col = 'image_path'
    csv_label_path_col = 'label_path'
    output_directory = '/home/thakuriu/fl_glaucoma_seg/detection_segmentation_v2/segmentation_train_and_inference/train_outputs'
    dataset_mean = [0.768, 0.476, 0.289]
    dataset_std = [0.221, 0.198, 0.165]
    lr = 0.00003
    batch_size = 8
    jitters = [0.2, 0.2, 0.05, 0.05, 0.75]
    num_epochs = 100
    patience = 7
    num_val_outputs_to_save = 5
    num_workers = 0  # 16

    device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

    # NOTE(review): `model_directory` was referenced but never defined in
    # this file; a `models` sub-folder of the output directory is used as the
    # default. Adjust if the elided setup code below defines its own.
    model_directory = os.path.join(output_directory, 'models')
    os.makedirs(model_directory, exist_ok=True)

    # [Your existing code for initializing output folders, transforms, etc.]
    # NOTE(review): `train_transform` and `id2label` are used below but are
    # not defined anywhere in this file; the elided setup code above must
    # provide them, otherwise this function raises NameError.

    # Prepare datasets and dataloaders
    datasets = [load_dataset(dataset_dir, train_transform) for dataset_dir in dataset_dirs]
    # The notebook values above replace the commented-out `args` namespace.
    dataloaders = [
        DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        for ds in datasets
    ]

    # Initialize models, optimizers, and criterion
    models = [Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-large-cityscapes-semantic",
                                                                 id2label=id2label,
                                                                 ignore_mismatched_sizes=True).to(device) for _ in dataset_dirs]

    # Layers re-created because of `ignore_mismatched_sizes=True` may be
    # initialised differently in every instance; start all clients from the
    # same weights so that the first averaging round is meaningful.
    initial_weights = models[0].state_dict()
    for model in models[1:]:
        model.load_state_dict(initial_weights)

    optimizers = [optim.AdamW(model.parameters(), lr=lr) for model in models]
    # NOTE(review): Mask2Former's forward pass presumably returns a structured
    # output object rather than a logits tensor, so applying CrossEntropyLoss
    # to it directly inside train_one_epoch likely needs adapting -- confirm
    # against the batch format produced by ImageSegmentationDataset.
    criterion = torch.nn.CrossEntropyLoss()

    # Keeps `global_weights` defined even if the loop below runs zero times.
    global_weights = models[0].state_dict()

    for epoch in range(num_epochs):
        weights = []

        # Train each model on its dataset
        for i, (model, dataloader, optimizer) in enumerate(zip(models, dataloaders, optimizers)):
            loss = train_one_epoch(model, dataloader, optimizer, criterion, device)
            print(f"Training Loss for Dataset {i+1}: {loss}")
            weights.append(model.state_dict())

        # Federated averaging
        global_weights = FedAvg(weights)
        for model in models:
            model.load_state_dict(global_weights)

        # Validation code can be added here

    # Save the global model
    torch.save(global_weights, os.path.join(model_directory, 'global_model.pth'))

# Run the training only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
