In [1]:
from SimSiam.fedfastsiam.client import *
from SimSiam.fedfastsiam.datapreparation import *
from SimSiam.fedfastsiam.server import *
from torch.utils.data import Subset


In [2]:
def load_data_iid(trainset, num_clients, batch_size):
    """Randomly partition ``trainset`` into one DataLoader per client (IID split).

    The dataset indices are shuffled with ``torch.randperm`` and cut into
    ``num_clients`` contiguous slices of ``len(trainset) // num_clients``
    indices each; the last client additionally receives the remainder.

    Args:
        trainset: Map-style dataset supporting ``len()`` and integer indexing.
        num_clients: Number of clients to split the data across (must be >= 1).
        batch_size: Batch size passed to every client's DataLoader.

    Returns:
        A list with one ``torch.utils.data.DataLoader`` per client, in client order.

    Raises:
        ValueError: If ``num_clients`` is smaller than 1.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be at least 1, got {num_clients}")

    # Convert to plain Python ints: Subset forwards each index unchanged to the
    # dataset, and 0-dim tensors are not valid keys for dict/pandas-backed datasets.
    shuffled_indices = torch.randperm(len(trainset)).tolist()
    num_indices = len(shuffled_indices)
    indices_per_client = num_indices // num_clients

    local_dataloaders = []
    for client in range(num_clients):
        start = client * indices_per_client
        # The last client absorbs whatever is left over after the integer division.
        if client == num_clients - 1:
            end = num_indices
        else:
            end = start + indices_per_client

        subset = Subset(trainset, shuffled_indices[start:end])
        local_dataloader = torch.utils.data.DataLoader(subset, batch_size=batch_size)
        local_dataloaders.append(local_dataloader)
    return local_dataloaders

In [3]:
def create_datasets(num_clients, dataset_size=5_000, batch_size=64, num_views=4):
    """Build augmented per-client DataLoaders via an IID split.

    A ``ZenseactDataset`` of ``dataset_size`` samples is wrapped with a
    ``TwoCropsTransform`` that applies the augmentation pipeline below, and the
    result is handed to ``load_data_iid`` for distribution over ``num_clients``.

    Args:
        num_clients: Number of clients to create DataLoaders for.
        dataset_size: ``size`` argument forwarded to ``ZenseactDataset``.
        batch_size: Batch size of every client DataLoader.
        num_views: ``num_views`` argument forwarded to ``TwoCropsTransform``.

    Returns:
        The list of DataLoaders returned by ``load_data_iid``.
    """
    # Colour jitter is only applied to 80% of the crops.
    color_jitter = transforms.RandomApply(
        [transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)],
        p=0.8,
    )
    # No normalisation step is part of the pipeline; tensors come straight from ToTensor.
    pipeline = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        color_jitter,
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    multi_view_transform = TwoCropsTransform(pipeline, num_views)
    trainset = ZenseactDataset(dataset_size, transform=multi_view_transform)
    return load_data_iid(trainset, num_clients, batch_size)

In [4]:
# Smoke test: build the loaders for 5 clients (all other arguments at their
# defaults) and echo the resulting list.
local_dataloaders = create_datasets(5)
local_dataloaders

[<torch.utils.data.dataloader.DataLoader at 0x7f0e08389e70>,
 <torch.utils.data.dataloader.DataLoader at 0x7f0e0893e800>,
 <torch.utils.data.dataloader.DataLoader at 0x7f0e0893ce50>,
 <torch.utils.data.dataloader.DataLoader at 0x7f0ee33e1f90>,
 <torch.utils.data.dataloader.DataLoader at 0x7f0e0808fd00>]

In [5]:
# Scratch cell: step-by-step IID split of an untransformed dataset.
# NOTE(review): this repeats the body of load_data_iid() at notebook level
# (plus a debug print); consider calling the function instead.
trainset = ZenseactDataset(5000)
num_clients=5
batch_size=32

# Random permutation of all dataset indices, printed for inspection.
shuffled_indices = torch.randperm(len(trainset))
print(shuffled_indices)
# Get the total number of indices
num_indices = len(shuffled_indices)

# Determine the number of indices per client
indices_per_client = num_indices // num_clients

# Distribute indices among clients
local_dataloaders = []
for i in range(num_clients):
    start = i * indices_per_client
    if i == num_clients - 1:  # Last client gets remaining indices
        end = num_indices
    else:
        end = start + indices_per_client
    client_indices = shuffled_indices[start:end]

    # Create the subset of the dataset that belongs to this client
    subset = Subset(trainset, client_indices)
    # Create the client's DataLoader (no shuffle argument is passed)
    local_dataloder = torch.utils.data.DataLoader(subset, batch_size=batch_size)

    local_dataloaders.append(local_dataloder)

tensor([3152, 2680, 3892,  ..., 3555, 4714, 1882])


In [6]:
# Number of batches the first client's DataLoader yields per pass.
len(local_dataloaders[0])

32

In [7]:
# Re-create the dataset without a transform and inspect its `folders` attribute.
trainset = ZenseactDataset(5000)
trainset.folders

['../../../mnt/nfs_mount/single_frames/000000',
 '../../../mnt/nfs_mount/single_frames/000001',
 '../../../mnt/nfs_mount/single_frames/000002',
 '../../../mnt/nfs_mount/single_frames/000003',
 '../../../mnt/nfs_mount/single_frames/000004',
 '../../../mnt/nfs_mount/single_frames/000005',
 '../../../mnt/nfs_mount/single_frames/000006',
 '../../../mnt/nfs_mount/single_frames/000007',
 '../../../mnt/nfs_mount/single_frames/000008',
 '../../../mnt/nfs_mount/single_frames/000009',
 '../../../mnt/nfs_mount/single_frames/000010',
 '../../../mnt/nfs_mount/single_frames/000011',
 '../../../mnt/nfs_mount/single_frames/000012',
 '../../../mnt/nfs_mount/single_frames/000013',
 '../../../mnt/nfs_mount/single_frames/000014',
 '../../../mnt/nfs_mount/single_frames/000015',
 '../../../mnt/nfs_mount/single_frames/000016',
 '../../../mnt/nfs_mount/single_frames/000017',
 '../../../mnt/nfs_mount/single_frames/000018',
 '../../../mnt/nfs_mount/single_frames/000019',
 '../../../mnt/nfs_mount/single_frames/0

In [5]:
# Instantiate the Server imported from SimSiam.fedfastsiam.server with
# 5 clients, 5 rounds, 5 local epochs and batch size 32, then echo it.
# NOTE(review): output_path is an empty string here — confirm how Server treats it.
server = Server(num_clients=5, output_path="", num_rounds=5, 
                local_epochs=5, batch_size=32)
server

<SimSiam.fedfastsiam.server.Server at 0x7f8b818a3bb0>

In [6]:
# The recorded output shows setup() printing a list of five DataLoaders;
# presumably it also creates the clients — confirm in server.py.
server.setup()

[<torch.utils.data.dataloader.DataLoader object at 0x7f8b812fd450>, <torch.utils.data.dataloader.DataLoader object at 0x7f8b812fd5a0>, <torch.utils.data.dataloader.DataLoader object at 0x7f8b812fd6f0>, <torch.utils.data.dataloader.DataLoader object at 0x7f8b812fd840>, <torch.utils.data.dataloader.DataLoader object at 0x7f8b812fd990>]


In [7]:
# presumably distributes the server's model to the clients — verify in server.py
server.send_model()

In [11]:
# server.clients[0].model

In [12]:
# Echo the DataLoader held by the first client.
server.clients[0].dataloader

<torch.utils.data.dataloader.DataLoader at 0x7f8b812fd450>

In [13]:
# Run the first client's local update.
# NOTE(review): the recorded progress output shows `0it` for every epoch, which
# suggests the client's DataLoader yielded no batches in that run — verify the
# dataset path resolves from the notebook's working directory.
server.clients[0].client_update()

Epoch 1/5: 0it [00:00, ?it/s]     | 0/5 [00:00<?, ?it/s]
Epoch 2/5: 0it [00:00, ?it/s]     | 0/5 [00:00<?, ?it/s, epoch=0]
Epoch 3/5: 0it [00:00, ?it/s]     | 0/5 [00:00<?, ?it/s, epoch=1]
Epoch 4/5: 0it [00:00, ?it/s]     | 0/5 [00:00<?, ?it/s, epoch=2]
Epoch 5/5: 0it [00:00, ?it/s]     | 0/5 [00:00<?, ?it/s, epoch=3]
Training client 1: 100%|██████████| 5/5 [00:00<00:00, 254.39it/s, epoch=4]

last learning rate:  0.0075





In [15]:
# Inspect what the first client's DataLoader yields, batch by batch.
# assumes each batch is (samples, labels), matching the (sample, dummy_label)
# tuples returned by the dataset classes in this notebook — TODO confirm
for idx, data in enumerate(server.clients[0].dataloader):
    images = data[0]
    print(images)
    # presumably images is [q, views] as returned by TwoCropsTransform — verify
    image = images[0][0]
    views = images[1]


In [23]:
# Echo the first client's DataLoader again (same object address as in the earlier output).
server.clients[0].dataloader

<torch.utils.data.dataloader.DataLoader at 0x7f8b812fd450>

In [11]:
class ZenseactSSLDataset(Dataset):
    """Map-style dataset over already-loaded samples, paired with a dummy label.

    Args:
        data: Indexable, sized collection of samples (e.g. a list of images).
        transform: Optional callable applied to a sample before it is returned.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __getitem__(self, idx):
        """Return ``(sample, 0)`` for position ``idx``, transformed if a transform is set."""
        dat = self.data[idx]
        if self.transform is not None:
            # The transform is a callable; it must be called, not subscripted.
            dat = self.transform(dat)
        dummy_label = 0
        return dat, dummy_label

    def __len__(self):
        """Return the number of samples in ``data``."""
        return len(self.data)
    
class TwoCropsTransform:
    """Produce one query crop plus ``num_views`` further crops of the same input."""

    def __init__(self, base_transform, num_views):
        # base_transform: callable applied to the input once per generated crop.
        # num_views: number of crops generated in addition to the query crop.
        self.num_views = num_views
        self.base_transform = base_transform

    def __call__(self, x):
        """Return ``[[query], [view_1, ..., view_n]]`` with ``n == num_views``.

        The query crop is generated first, followed by the views in order.
        """
        augment = self.base_transform
        query = [augment(x)]
        extra_views = [augment(x) for _ in range(self.num_views)]
        return [query, extra_views]

In [9]:
def generate_ssl_data(size=50_000, parent_directory="../../../mnt/nfs_mount/single_frames"):
    """Load one downscaled RGB front-camera image per frame folder.

    Frame folders are the sub-directories of ``parent_directory`` whose name
    consists of exactly six digits. For each of the first ``size`` folders (in
    sorted order) the first ``*.jpg`` inside ``camera_front_blur/`` is opened,
    converted to RGB and resized to ``RESCALE_SIZE``.

    Args:
        size: Maximum number of frame folders to read.
        parent_directory: Directory containing the six-digit frame folders.

    Returns:
        List of resized PIL images, one per folder that contained a jpg.
        Folders without a jpg are skipped and reported with a warning.
    """
    import warnings

    # glob returns matches in arbitrary order; sort so the selected subset is reproducible.
    folder_pattern = os.path.join(parent_directory, "[0-9]" * 6)
    folders = sorted(glob.glob(folder_pattern))[:size]
    if not folders:
        warnings.warn(f"No frame folders found matching {folder_pattern!r}")

    image_data = []
    skipped = 0
    for folder in folders:
        jpg_paths = sorted(glob.glob(os.path.join(folder, "camera_front_blur", "*.jpg")))
        if not jpg_paths:
            # Previously this surfaced as a bare IndexError on jpg_paths[0].
            skipped += 1
            continue
        # Close the source file deterministically; convert/resize return new images.
        with Image.open(jpg_paths[0]) as image:
            image_data.append(image.convert('RGB').resize(RESCALE_SIZE))

    if skipped:
        warnings.warn(f"Skipped {skipped} frame folder(s) without a camera_front_blur/*.jpg image")
    return image_data

[<PIL.Image.Image image mode=RGB size=120x100>,
 <PIL.Image.Image image mode=RGB size=120x100>,
 <PIL.Image.Image image mode=RGB size=120x100>,
 <PIL.Image.Image image mode=RGB size=120x100>,
 <PIL.Image.Image image mode=RGB size=120x100>,
 <PIL.Image.Image image mode=RGB size=120x100>,
 <PIL.Image.Image image mode=RGB size=120x100>,
 <PIL.Image.Image image mode=RGB size=120x100>,
 <PIL.Image.Image image mode=RGB size=120x100>,
 <PIL.Image.Image image mode=RGB size=120x100>,
 <PIL.Image.Image image mode=RGB size=120x100>,
 <PIL.Image.Image image mode=RGB size=120x100>,
 <PIL.Image.Image image mode=RGB size=120x100>,
 <PIL.Image.Image image mode=RGB size=120x100>,
 <PIL.Image.Image image mode=RGB size=120x100>,
 <PIL.Image.Image image mode=RGB size=120x100>,
 <PIL.Image.Image image mode=RGB size=120x100>,
 <PIL.Image.Image image mode=RGB size=120x100>,
 <PIL.Image.Image image mode=RGB size=120x100>,
 <PIL.Image.Image image mode=RGB size=120x100>,
 <PIL.Image.Image image mode=RGB size=12

In [12]:
# NOTE(review): `image_data` is not assigned in any visible cell — presumably it
# holds the result of generate_ssl_data(); confirm before running on a fresh kernel.
ZenseactSSLDataset(image_data)

<__main__.ZenseactSSLDataset at 0x7f0e07e6f190>

In [None]:
class ZenseactDataset(Dataset):
    """Lazily loaded dataset with one front-camera image per frame folder.

    Frame folders are the sub-directories of ``parent_directory`` whose name
    consists of exactly six digits. Only paths are collected at construction
    time; images are opened, converted to RGB and resized to ``RESCALE_SIZE``
    on access.

    Args:
        size: Maximum number of frame folders to consider.
        transform: Optional callable applied to the resized image.
        parent_directory: Directory containing the six-digit frame folders.

    Attributes:
        folders: Frame folders that contain at least one jpg, in sorted order.
        image_paths: For each entry of ``folders``, the sorted list of jpg
            paths found in its ``camera_front_blur/`` sub-directory.
    """

    def __init__(self, size=50_000, transform=None,
                 parent_directory="../../../mnt/nfs_mount/single_frames"):
        import warnings

        self.transform = transform
        self.folders = []
        self.image_paths = []

        # glob returns matches in arbitrary order; sort so the selected subset is reproducible.
        folder_pattern = os.path.join(parent_directory, "[0-9]" * 6)
        candidates = sorted(glob.glob(folder_pattern))[:size]
        if not candidates:
            warnings.warn(f"No frame folders found matching {folder_pattern!r}")

        skipped = 0
        for folder in candidates:
            jpg_paths = sorted(glob.glob(os.path.join(folder, "camera_front_blur", "*.jpg")))
            if not jpg_paths:
                # Drop folders without an image now, so that len() only counts
                # loadable samples and __getitem__ cannot fail on an empty list.
                skipped += 1
                continue
            self.folders.append(folder)
            self.image_paths.append(jpg_paths)

        if skipped:
            warnings.warn(f"Skipped {skipped} frame folder(s) without a camera_front_blur/*.jpg image")

    def __len__(self):
        """Return the number of loadable samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, 0)`` for sample ``idx``; the image is transformed if a transform is set."""
        image_path = self.image_paths[idx][0]
        # Close the source file deterministically; convert/resize return new images.
        with Image.open(image_path) as image:
            downsampled_image = image.convert('RGB').resize(RESCALE_SIZE)

        if self.transform is not None:
            downsampled_image = self.transform(downsampled_image)

        dummy_label = 0
        return downsampled_image, dummy_label
    

class TwoCropsTransform:
    """Produce one query crop plus ``num_views`` further crops of the same input."""

    def __init__(self, base_transform, num_views):
        # base_transform: callable applied to the input once per generated crop.
        # num_views: number of crops generated in addition to the query crop.
        self.num_views = num_views
        self.base_transform = base_transform

    def __call__(self, x):
        """Return ``[[query], [view_1, ..., view_n]]`` with ``n == num_views``.

        The query crop is generated first, followed by the views in order.
        """
        augment = self.base_transform
        query = [augment(x)]
        extra_views = [augment(x) for _ in range(self.num_views)]
        return [query, extra_views]