In [6]:
import torch
from dataset import ClientDataset
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import Subset, DataLoader, Dataset 
from sklearn.model_selection import train_test_split
from torchvision import transforms
import os
from glob import glob
from PIL import Image
class ImagePathsDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths,class_indexes, transform=None):
        self.image_paths = image_paths
        self.transform = transform
        self.class_indexes=class_indexes

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        class_name = os.path.basename(os.path.dirname(image_path))
        image = Image.open(image_path)
        if self.transform:
            image = self.transform(image)
        return image, int(self.class_indexes[class_name])  # class_name'i doğrudan döndür

    def __len__(self):
        return len(self.image_paths)


class ClientDataset():
    def __init__(self, folder_path, train_ratio=0.8):
        self.client_folder_paths = folder_path
        self.train_ratio = train_ratio
        self.class_names = sorted(os.listdir(folder_path))
        self.transform = transforms.Compose([
            transforms.Resize((160, 160)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.class_indexes=self.get_class_names()
    def get_class_names(self):
        class_indexes = {}
        for idx, class_name in enumerate(self.class_names):
            class_indexes[class_name] = f"{idx+105}"
        return class_indexes

    def split_images(self, folder_path, class_names):
        test_list_full = []

        for class_name in class_names:
            class_folder = os.path.join(folder_path, class_name)
            image_paths = glob(class_folder+"/*")
            test_list_full += image_paths
        return test_list_full
    
    def get_num_classes(self):
        return len(self.class_names)

    def load_client_data(self, batch_size=32):
        test_list = self.split_images(self.client_folder_paths, self.class_names)
        test_dataset = ImagePathsDataset(test_list,self.class_indexes, transform=self.transform)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        return test_loader
    
client_id=1
client_path=f"./celeba_500/celeba/custom_celeba_500/clients_test/client_{client_id}"
client_dataset=ClientDataset(client_path)
testloader=client_dataset.load_client_data()
new_num_classes=client_dataset.get_num_classes()

In [24]:
client_dict=client_dataset.get_class_names()
pred_to_label={v:k for k,v in client_dict.items()}

# Global model yükle


In [11]:
from model import get_client_model, load_centralized_model
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
cent_net=load_centralized_model()
updated_global_model=torch.load(f"./models/run_2024_05_19-04_37/global_models/model_round_1.pth",map_location=device)
cent_net.load_state_dict(updated_global_model)

Global Model loaded successfully


<All keys matched successfully>

In [12]:
cent_net.fc

Linear(in_features=1280, out_features=105, bias=True)

In [13]:
from model import get_client_model, load_centralized_model
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
client_net=get_client_model(client_id,new_num_classes)
updated_global_model=torch.load(f"./models/run_2024_05_19-04_37/client_models/clientid_{client_id}/model_round_1.pth",map_location=device)
client_net.load_state_dict(updated_global_model)

Global Model loaded successfully


<All keys matched successfully>

In [14]:
client_net.fc

Linear(in_features=1280, out_features=153, bias=True)

In [17]:
test_client_net=get_client_model(client_id,new_num_classes)
# Centralized modelin features ve fc layerının [:105] kısmını al
cent_features_weights = cent_net.features.state_dict()
# cent_fc_weights = {k: v for k, v in cent_net.fc.state_dict().items() if int(k.split('.')[-1]) < 105}

# # Client modelin fc layerının [105:] kısmını al
# client_fc_weights = {k: v for k, v in client_net.fc.state_dict().items() if int(k.split('.')[-1]) >= 105}

# Yeni modelin ağırlıklarını yükle
test_client_net.features.load_state_dict(cent_features_weights)
with torch.no_grad():
    test_client_net.fc.weight[:105].copy_(cent_net.fc.weight[:105])
    test_client_net.fc.bias[:105].copy_(cent_net.fc.bias[:105])
    test_client_net.fc.weight[105:].copy_(client_net.fc.weight[105:])
    
test_client_net.to(device)


Global Model loaded successfully


ClientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=1e

In [26]:
import numpy as np
import matplotlib.pyplot as plt
class_names = client_dataset.get_class_names()
for images, labels in testloader:
    images, labels = images.to(device), labels.to(device)
    outputs = test_client_net(images)
    _, predicted = torch.max(outputs, 1)
    print("Predicted: ", [pred_to_label[p.item()] for p in predicted])
    break


KeyError: 30