In [2]:
from google.colab import files

uploaded = files.upload()

Saving kvasir-seg.zip to kvasir-seg (1).zip


In [3]:
!unzip -q "kvasir-seg (1).zip" -d /content/kvasir-seg

In [4]:
!pip install "flwr<1.11.0" numpy



In [5]:
import os
import glob

import flwr as fl
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import Dict, Any, List, Tuple

from PIL import Image

  return datetime.utcnow().replace(tzinfo=utc)


In [6]:
CLIENT_ID = 2

SERVER_ADDRESS = "0.tcp.jp.ngrok.io:12014"

In [7]:
BASE_DATA_DIR = "/content/kvasir-seg/Kvasir-SEG"

In [8]:
class KvasirSegDataset(Dataset):
    """
    - images/ 안의 파일과 masks/ 안의 파일을
      '같은 이름 + (확장자만 다를 수 있음)' 기준으로 매칭
    """
    def __init__(self, images_dir: str, masks_dir: str):
        super().__init__()
        self.images_dir = images_dir
        self.masks_dir = masks_dir

        # 이미지 파일 후보들 (jpg/jpeg/png)
        img_paths = sorted(
            glob.glob(os.path.join(images_dir, "*.jpg"))
            + glob.glob(os.path.join(images_dir, "*.jpeg"))
            + glob.glob(os.path.join(images_dir, "*.png"))
        )

        if len(img_paths) == 0:
            raise RuntimeError(f"이미지 파일 없음: {images_dir}")

        self.pairs = []  # (img_path, mask_path) 리스트

        for img_path in img_paths:
            base = os.path.splitext(os.path.basename(img_path))[0]

            # 마스크는 png/jpg/jpeg 중 실제로 존재하는 걸 사용
            cand_masks = [
                os.path.join(masks_dir, base + ext)
                for ext in [".png", ".jpg", ".jpeg"]
            ]

            mask_path = None
            for cm in cand_masks:
                if os.path.exists(cm):
                    mask_path = cm
                    break

            if mask_path is None:
                # 해당 이미지에 대응되는 마스크가 없으면 그냥 스킵
                print(f"[WARN] 마스크 없음, 스킵: {img_path}")
                continue

            self.pairs.append((img_path, mask_path))

        if len(self.pairs) == 0:
            raise RuntimeError(f"매칭되는 이미지-마스크 쌍이 없습니다. masks 폴더 구조를 확인하세요.")

        print(f"[INFO] 유효한 이미지-마스크 쌍 개수: {len(self.pairs)}")

    def __len__(self):
        return len(self.pairs)

    def _load_and_preprocess(self, img_path, is_mask=False):
        img = Image.open(img_path)

        if is_mask:
            img = img.convert("L")
            img = img.resize((256, 256), resample=Image.NEAREST)
        else:
            img = img.convert("RGB")
            img = img.resize((256, 256), resample=Image.BILINEAR)
            img = img.convert("L")

        arr = np.array(img, dtype=np.float32)

        if arr.ndim == 2:
            arr = arr[None, :, :]
        else:
            arr = arr.transpose(2, 0, 1)

        arr = arr / 255.0
        return torch.tensor(arr, dtype=torch.float32)

    def __getitem__(self, idx):
        img_path, mask_path = self.pairs[idx]

        image = self._load_and_preprocess(img_path, is_mask=False)
        mask = self._load_and_preprocess(mask_path, is_mask=True)
        mask = (mask > 0.5).float()

        return image, mask

def dirichlet_split_dataset(full_dataset, num_clients=3, alpha=0.5, client_id=0):
    """
    Dirichlet(α) 기반으로 전체 데이터셋을 num_clients 개로 나눈 뒤,
    해당 client_id 에 할당된 인덱스만 반환.

    - alpha: Dirichlet α (논문에서는 1.0 / 0.5 / 0.1 사용)
    """
    np.random.seed(42)

    dataset_size = len(full_dataset)
    all_indices = np.arange(dataset_size)

    # 각 클라이언트 비율을 Dirichlet 분포에서 샘플링
    proportions = np.random.dirichlet([alpha] * num_clients)

    # 비율 * 전체 개수 → 각 클라이언트가 가질 샘플 수
    client_sizes = (proportions * dataset_size).astype(int)

    # 반올림/내림 과정에서 빠진 샘플을 마지막 클라이언트에 몰아주기
    client_sizes[-1] += dataset_size - client_sizes.sum()

    np.random.shuffle(all_indices)

    idx_map = {}
    start = 0
    for i in range(num_clients):
        end = start + client_sizes[i]
        idx_map[i] = all_indices[start:end]
        start = end

    client_indices = idx_map[client_id]
    print(f"[DIRICHLET] alpha={alpha}, client_id={client_id}, num_samples={len(client_indices)}")

    return client_indices


def make_dataloaders(alpha=0.5, num_clients=3) -> Tuple[DataLoader, DataLoader]:
    """
    이 클라이언트가 사용할 trainloader, valloader 생성
    - 논문과 동일하게 Dirichlet(α) 분포 기반 Non-IID 분할 적용
    - 그다음, 이 클라이언트의 데이터만 8:2로 train/val로 나눔
    """

    images_dir = os.path.join(BASE_DATA_DIR, "images")
    masks_dir = os.path.join(BASE_DATA_DIR, "masks")

    full_dataset = KvasirSegDataset(images_dir, masks_dir)

    client_indices = dirichlet_split_dataset(
        full_dataset=full_dataset,
        num_clients=num_clients,
        alpha=alpha,
        client_id=CLIENT_ID,
    )

    # 이 클라이언트 전용 Subset
    client_dataset = torch.utils.data.Subset(full_dataset, client_indices)

    # 클라이언트 내부에서 8:2 비율로 train/val 분할
    n_total = len(client_dataset)
    n_train = int(n_total * 0.8)
    n_val = n_total - n_train

    train_dataset, val_dataset = torch.utils.data.random_split(
        client_dataset,
        [n_train, n_val],
        generator=torch.Generator().manual_seed(42),  # 재현성
    )

    trainloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    valloader = DataLoader(val_dataset, batch_size=8, shuffle=False)
    return trainloader, valloader

In [9]:
class SimpleUNet(nn.Module):
    def __init__(self, in_channels: int = 1, out_channels: int = 1, features: int = 32):
        super().__init__()

        self.enc1 = self._block(in_channels, features)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = self._block(features, features * 2)
        self.pool2 = nn.MaxPool2d(2)

        self.bottleneck = self._block(features * 2, features * 4)

        self.up2 = nn.ConvTranspose2d(features * 4, features * 2, kernel_size=2, stride=2)
        self.dec2 = self._block(features * 4, features * 2)
        self.up1 = nn.ConvTranspose2d(features * 2, features, kernel_size=2, stride=2)
        self.dec1 = self._block(features * 2, features)

        self.final_conv = nn.Conv2d(features, out_channels, kernel_size=1)

    def _block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x1 = self.enc1(x)
        x2 = self.enc2(self.pool1(x1))
        x3 = self.bottleneck(self.pool2(x2))

        x = self.up2(x3)
        x = torch.cat([x, x2], dim=1)
        x = self.dec2(x)

        x = self.up1(x)
        x = torch.cat([x, x1], dim=1)
        x = self.dec1(x)

        return self.final_conv(x)


def model_fn() -> nn.Module:
    """
    Flower 클라이언트에서 사용할 모델 생성 함수
    - 나중에 SimpleUNet 대신 논문 구조의 U-Net으로 교체하면 됨
    """
    return SimpleUNet(in_channels=1, out_channels=1, features=32)

In [10]:
def dice_loss(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """
    세그멘테이션용 Dice Loss
    pred: (B,1,H,W), sigmoid 이후 값
    target: (B,1,H,W), 0/1 마스크
    """
    pred_flat = pred.view(pred.size(0), -1)
    target_flat = target.view(target.size(0), -1)

    inter = (pred_flat * target_flat).sum(dim=1)
    union = pred_flat.sum(dim=1) + target_flat.sum(dim=1)

    dice = (2 * inter + eps) / (union + eps)
    return 1 - dice.mean()

In [11]:
class SimpleClient(fl.client.NumPyClient):
    def __init__(self):
        # 이 클라이언트의 ID (서버 쪽에서 구분용)
        self.cid = str(CLIENT_ID)

        # 모델, 데이터로더, 디바이스 준비
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model_fn().to(self.device)
        self.trainloader, self.valloader = make_dataloaders()

    # ---------- Flower 필수 메서드 1: get_parameters ----------
    def get_parameters(self, config: Dict[str, Any]) -> List[np.ndarray]:
        """
        서버가 "초기 파라미터 보내줘" 할 때 호출됨.
        모델의 state_dict를 numpy 배열 리스트로 변환해서 반환.
        """
        state_dict = self.model.state_dict()
        return [v.cpu().numpy() for _, v in state_dict.items()]

    # ---------- Flower 필수 메서드 2: set_parameters ----------
    def set_parameters(self, parameters: List[np.ndarray]) -> None:
        """
        서버에서 내려준 글로벌 파라미터를 로컬 모델에 로드.
        """
        state_dict = self.model.state_dict()
        new_state = {}
        for (k, _), p in zip(state_dict.items(), parameters):
            new_state[k] = torch.tensor(p)
        self.model.load_state_dict(new_state, strict=True)

    # ---------- Flower 필수 메서드 3: fit (로컬 학습) ----------
    def fit(
        self,
        parameters: List[np.ndarray],
        config: Dict[str, Any],
    ) -> Tuple[List[np.ndarray], int, Dict[str, Any]]:

        # 1) 글로벌 파라미터 적용
        self.set_parameters(parameters)
        self.model.to(self.device)
        self.model.train()

        # 2) 서버가 fit_config로 내려준 설정 사용
        local_epochs = int(config.get("local_epochs", 1))
        lr = float(config.get("learning_rate", 1e-3))
        batch_size = int(config.get("batch_size", 32))
        rnd = int(config.get("round", -1))

        print(f"[CLIENT {self.cid}] Round {rnd} | epochs={local_epochs}, lr={lr}, bs={batch_size}")

        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

        running_loss = 0.0
        running_dice = 0.0
        num_batches = 0

        for epoch in range(local_epochs):
            for images, masks in self.trainloader:
                images = images.to(self.device)
                masks = masks.to(self.device)

                logits = self.model(images)
                probs = torch.sigmoid(logits)

                loss = dice_loss(probs, masks)

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)
                optimizer.step()

                running_loss += loss.item()
                running_dice += (1.0 - loss.item())
                num_batches += 1

        avg_loss = running_loss / max(1, num_batches)
        avg_dice = running_dice / max(1, num_batches)

        print(f"[CLIENT {self.cid}] Round {rnd} | train_loss={avg_loss:.4f}, train_dice={avg_dice:.4f}")

        # 3) 업데이트된 파라미터 반환
        state_dict = self.model.state_dict()
        new_params = [v.detach().cpu().numpy() for _, v in state_dict.items()]
        num_examples = len(self.trainloader.dataset)

        metrics = {
            "train_loss": float(avg_loss),
            "train_dice": float(avg_dice),
        }

        return new_params, num_examples, metrics

    # ---------- Flower 필수 메서드 4: evaluate (로컬 평가) ----------
    def evaluate(
        self,
        parameters: List[np.ndarray],
        config: Dict[str, Any],
    ) -> Tuple[float, int, Dict[str, Any]]:

        self.set_parameters(parameters)
        self.model.to(self.device)
        self.model.eval()

        total_loss = 0.0
        total_dice = 0.0
        num_batches = 0

        with torch.no_grad():
            for images, masks in self.valloader:
                images = images.to(self.device)
                masks = masks.to(self.device)

                logits = self.model(images)
                probs = torch.sigmoid(logits)

                l = dice_loss(probs, masks)
                total_loss += l.item()
                total_dice += (1.0 - l.item())
                num_batches += 1

        avg_loss = total_loss / max(1, num_batches)
        avg_dice = total_dice / max(1, num_batches)
        num_examples = len(self.valloader.dataset)

        print(f"[CLIENT {self.cid}] Eval | loss={avg_loss:.4f}, dice={avg_dice:.4f}")

        return float(avg_loss), num_examples, {"val_dice": float(avg_dice)}

In [12]:
print(f"[CLIENT {CLIENT_ID}] Connecting...") # 서버 연결 부분 수정X
fl.client.start_client(
    server_address=SERVER_ADDRESS,
    client=SimpleClient().to_client(),
)

DEBUG:flwr:Opened insecure gRPC connection (no certificates were passed)
  return datetime.utcnow().replace(tzinfo=utc)
DEBUG:flwr:ChannelConnectivity.IDLE
DEBUG:flwr:ChannelConnectivity.CONNECTING


[CLIENT 2] Connecting...
[INFO] 유효한 이미지-마스크 쌍 개수: 1000
[DIRICHLET] alpha=0.5, client_id=2, num_samples=30


DEBUG:flwr:ChannelConnectivity.READY
[92mINFO [0m:      
INFO:flwr:
[92mINFO [0m:      Received: train message 57bb6e79-b855-4750-87bd-eb310b09c302
INFO:flwr:Received: train message 57bb6e79-b855-4750-87bd-eb310b09c302


[CLIENT 2] Round 1 | epochs=1, lr=0.001, bs=32


[92mINFO [0m:      Sent reply
  return datetime.utcnow().replace(tzinfo=utc)
INFO:flwr:Sent reply


[CLIENT 2] Round 1 | train_loss=0.7719, train_dice=0.2281


[92mINFO [0m:      
INFO:flwr:
[92mINFO [0m:      Received: evaluate message 93785f07-a713-47e6-8ef0-60b6a03798a5
INFO:flwr:Received: evaluate message 93785f07-a713-47e6-8ef0-60b6a03798a5
[92mINFO [0m:      Sent reply
INFO:flwr:Sent reply


[CLIENT 2] Eval | loss=0.6388, dice=0.3612


[92mINFO [0m:      
INFO:flwr:
[92mINFO [0m:      Received: train message 7efa41ac-5fb7-4e5c-bcd9-fe30ccc5ce90
INFO:flwr:Received: train message 7efa41ac-5fb7-4e5c-bcd9-fe30ccc5ce90


[CLIENT 2] Round 2 | epochs=1, lr=0.001, bs=32


[92mINFO [0m:      Sent reply
INFO:flwr:Sent reply


[CLIENT 2] Round 2 | train_loss=0.6823, train_dice=0.3177


[92mINFO [0m:      
INFO:flwr:
[92mINFO [0m:      Received: evaluate message b7eb2c16-d6f9-4d6f-a637-593ebe2fc1ea
INFO:flwr:Received: evaluate message b7eb2c16-d6f9-4d6f-a637-593ebe2fc1ea
[92mINFO [0m:      Sent reply
INFO:flwr:Sent reply


[CLIENT 2] Eval | loss=0.4657, dice=0.5343


[92mINFO [0m:      
INFO:flwr:
[92mINFO [0m:      Received: train message 39c32ab6-193a-4fb2-95c4-8cddb97c5509
INFO:flwr:Received: train message 39c32ab6-193a-4fb2-95c4-8cddb97c5509


[CLIENT 2] Round 3 | epochs=1, lr=0.001, bs=32


[92mINFO [0m:      Sent reply
INFO:flwr:Sent reply


[CLIENT 2] Round 3 | train_loss=0.6501, train_dice=0.3499


[92mINFO [0m:      
INFO:flwr:
[92mINFO [0m:      Received: evaluate message aad0ecb0-9a16-412e-ad9e-0fbc2e0997f2
INFO:flwr:Received: evaluate message aad0ecb0-9a16-412e-ad9e-0fbc2e0997f2
[92mINFO [0m:      Sent reply
INFO:flwr:Sent reply


[CLIENT 2] Eval | loss=0.4269, dice=0.5731


[92mINFO [0m:      
INFO:flwr:
[92mINFO [0m:      Received: train message be9d0d89-740d-4930-8325-f6bf11bd5dcf
INFO:flwr:Received: train message be9d0d89-740d-4930-8325-f6bf11bd5dcf


[CLIENT 2] Round 4 | epochs=1, lr=0.001, bs=32


[92mINFO [0m:      Sent reply
INFO:flwr:Sent reply


[CLIENT 2] Round 4 | train_loss=0.6368, train_dice=0.3632


[92mINFO [0m:      
INFO:flwr:
[92mINFO [0m:      Received: evaluate message c8d0a470-39f1-43a8-a556-9c63142bfa20
INFO:flwr:Received: evaluate message c8d0a470-39f1-43a8-a556-9c63142bfa20
[92mINFO [0m:      Sent reply
INFO:flwr:Sent reply


[CLIENT 2] Eval | loss=0.4262, dice=0.5738


[92mINFO [0m:      
INFO:flwr:
[92mINFO [0m:      Received: train message c02f59d6-140b-4ac6-b11d-17ac568c95ba
INFO:flwr:Received: train message c02f59d6-140b-4ac6-b11d-17ac568c95ba


[CLIENT 2] Round 5 | epochs=1, lr=0.001, bs=32


[92mINFO [0m:      Sent reply
INFO:flwr:Sent reply


[CLIENT 2] Round 5 | train_loss=0.6395, train_dice=0.3605


[92mINFO [0m:      
INFO:flwr:
[92mINFO [0m:      Received: evaluate message 82f2fff3-dae4-4816-8aa9-869b21e5fe86
INFO:flwr:Received: evaluate message 82f2fff3-dae4-4816-8aa9-869b21e5fe86
[92mINFO [0m:      Sent reply
INFO:flwr:Sent reply


[CLIENT 2] Eval | loss=0.4840, dice=0.5160


[92mINFO [0m:      
INFO:flwr:
[92mINFO [0m:      Received: reconnect message b2578d99-dd97-4a7f-b914-3930443c1cb2
INFO:flwr:Received: reconnect message b2578d99-dd97-4a7f-b914-3930443c1cb2
DEBUG:flwr:gRPC channel closed
[92mINFO [0m:      Disconnect and shut down
INFO:flwr:Disconnect and shut down
  return datetime.utcnow().replace(tzinfo=utc)
