In [1]:
import json
from PIL import Image
import os, glob
from torch.utils.data import Dataset, DataLoader
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset
import pickle

# Target size for downsampled camera frames, passed directly to PIL's
# Image.resize in generate_ssl_data (PIL reads it as (width, height)).
RESCALE_SIZE = (120, 100)


class ZenseactSSLDataset(Dataset):
    """Map-style dataset over pre-loaded samples for self-supervised training.

    Each item is returned as ``(sample, 0)``: the optionally transformed
    sample paired with a constant dummy label.

    Args:
        data: indexable, sized collection of samples (e.g. the PIL images
            produced by generate_ssl_data).
        transform: optional callable, applied to each sample when truthy.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform
        # Not read anywhere in this class; kept for external users.
        self.roadcondition = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        transformed = self.transform(sample) if self.transform else sample
        return transformed, 0
    

class ZenseactMetadata(Dataset):
    """Thin map-style dataset that exposes per-frame metadata by index.

    Items are returned unchanged (in this notebook: weather-condition
    strings produced by generate_noniid_ssl_data).

    Args:
        data: indexable, sized collection of metadata records.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    

class TwoCropsTransform:
    """Apply ``base_transform`` several times to one input to build views.

    Despite the class name, this does not produce exactly two crops. Each
    call runs ``base_transform`` ``num_views + 1`` times: once for the query,
    then ``num_views`` more times. It returns
    ``[[query], [view_1, ..., view_num_views]]``, i.e. the query wrapped in a
    one-element list followed by the list of extra views.

    Args:
        base_transform: callable applied to the input to produce each view
            (presumably a random augmentation, so views differ).
        num_views: number of extra views generated besides the query.
    """

    def __init__(self, base_transform, num_views):
        self.base_transform = base_transform
        self.num_views = num_views

    def __call__(self, x):
        # Query first, then the extra views, so base_transform is invoked in
        # the same order as before (matters for stateful/random transforms).
        q = [self.base_transform(x)]
        views = [self.base_transform(x) for _ in range(self.num_views)]
        return [q, views]
    

def generate_ssl_data(size=25_000, parent_directory="../../../mnt/nfs_mount/single_frames"):
    """Load and downsample the front-camera image of up to ``size`` frames.

    Frames are the 6-digit-named folders directly under ``parent_directory``.
    For each one, the first ``camera_front_blur/*.jpg`` (in sorted filename
    order) is opened, converted to RGB and resized to ``RESCALE_SIZE``.

    NOTE(review): folders are taken in ``glob.glob`` order, which depends on
    the filesystem and is not guaranteed to be sorted.
    ``generate_noniid_ssl_data`` lists the folders the same way, and the
    notebook assumes image i and metadata i describe the same frame. The
    order is deliberately left as-is so an existing ``ssl_data.pkl`` cache
    stays aligned. Sort both listings and regenerate the cache if the order
    ever differs.

    Args:
        size: maximum number of frame folders to load.
        parent_directory: directory containing the 6-digit frame folders.

    Returns:
        list of PIL images, one per frame folder, in folder order.

    Raises:
        FileNotFoundError: if a frame folder has no ``camera_front_blur``
            jpg.
    """
    # Match folders whose name is exactly six digits.
    folder_pattern = os.path.join(parent_directory, "[0-9]" * 6)
    folders = glob.glob(folder_pattern)[:size]

    image_data = []
    for folder in folders:
        # Build the image path from the matched folder rather than
        # re-hard-coding the mount point.
        jpg_paths = sorted(glob.glob(os.path.join(folder, "camera_front_blur", "*.jpg")))
        if not jpg_paths:
            raise FileNotFoundError(f"no camera_front_blur jpg found in {folder}")
        # convert()/resize() return new in-memory images, so the source file
        # can be closed as soon as they are built.
        with Image.open(jpg_paths[0]) as image:
            downsampled_image = image.convert("RGB").resize(RESCALE_SIZE)
        image_data.append(downsampled_image)

    return image_data


def generate_noniid_ssl_data(size=25_000, parent_directory="../../../mnt/nfs_mount/single_frames"):
    """Read the ``scraped_weather`` field of up to ``size`` frames.

    Frames are the 6-digit-named folders directly under ``parent_directory``;
    each one's ``metadata.json`` is parsed and its ``"scraped_weather"`` value
    collected.

    NOTE(review): folders are taken in ``glob.glob`` order, the same way as
    in ``generate_ssl_data``. The notebook relies on both listings having the
    same order so that metadata i and image i describe the same frame. That
    order is filesystem-dependent; it is kept unsorted here to stay aligned
    with an existing image cache.

    Args:
        size: maximum number of frame folders to read.
        parent_directory: directory containing the 6-digit frame folders.

    Returns:
        list of ``scraped_weather`` values, one per frame folder, in folder
        order.

    Raises:
        KeyError: if a metadata.json has no ``"scraped_weather"`` key.
    """
    # Match folders whose name is exactly six digits.
    folder_pattern = os.path.join(parent_directory, "[0-9]" * 6)
    folders = glob.glob(folder_pattern)[:size]

    meta_data = []
    for folder in folders:
        metadata_path = os.path.join(folder, "metadata.json")
        # Close each file deterministically instead of relying on garbage
        # collection; JSON is read as UTF-8 regardless of the locale.
        with open(metadata_path, encoding="utf-8") as f:
            metadata = json.load(f)
        meta_data.append(metadata["scraped_weather"])

    return meta_data

In [2]:
# Weather label per frame. Uses the same size as the image cache so that
# metadata index i and image index i are expected to be the same frame.
meta_data = generate_noniid_ssl_data(size=25_000)

In [3]:
# Number of frames per weather condition, in first-seen order.
sum_dict = {}
for weather in meta_data:
    sum_dict[weather] = sum_dict.get(weather, 0) + 1

In [4]:
sum_dict  # display per-weather frame counts

{'partly-cloudy-day': 8535,
 'cloudy': 5099,
 'clear-day': 3737,
 'rain': 4491,
 'clear-night': 554,
 'snow': 436,
 'partly-cloudy-night': 1874,
 'wind': 36,
 'fog': 238}

In [5]:
# Cache the decoded images on disk; regenerating them "might take a while".
# NOTE(review): pickle.load can execute arbitrary code -- only load a cache
# file this notebook produced itself. The cache is not keyed on the sample
# count or RESCALE_SIZE; delete it after changing either.
SSL_CACHE_PATH = "ssl_data.pkl"
if os.path.exists(SSL_CACHE_PATH):
    print("data file found")
    with open(SSL_CACHE_PATH, "rb") as file:
        data = pickle.load(file)
else:
    print("generating data. This might take a while.")
    data = generate_ssl_data(25_000)
    with open(SSL_CACHE_PATH, "wb") as file:
        pickle.dump(data, file)

data file found


In [6]:
zenseactmetadata = ZenseactMetadata(meta_data)
zenseactssldataset = ZenseactSSLDataset(data)
# The non-IID split indexes images by metadata position, so both must cover
# the same frames; a length mismatch most likely means a stale image cache.
assert len(zenseactmetadata) == len(zenseactssldataset), (
    f"metadata ({len(zenseactmetadata)}) and images ({len(zenseactssldataset)}) differ in length"
)

In [7]:
# Client split configuration.
num_clients = 10
batch_size = 32
# Equal share per client (floor division: any remainder does not fit).
num_data_per_client = len(zenseactssldataset) // num_clients

In [8]:
# One initially empty index list per client, keyed "1" .. str(num_clients).
client_indices = {str(client): [] for client in range(1, num_clients + 1)}
client_indices

{'1': [],
 '2': [],
 '3': [],
 '4': [],
 '5': [],
 '6': [],
 '7': [],
 '8': [],
 '9': [],
 '10': []}

In [16]:
num_data_per_client  # maximum number of samples assigned to each client

2500

In [24]:
# Primary client (1-based) for each weather condition. Conditions not listed
# here (e.g. 'wind') go to client 9; client 10 only receives overflow.
weather_to_client = {
    "partly-cloudy-day": 1,
    "cloudy": 2,
    "clear-day": 3,
    "rain": 4,
    "clear-night": 5,
    "snow": 6,
    "partly-cloudy-night": 7,
    "fog": 8,
}
default_client = 9

# Start from empty buckets so re-running this cell does not append duplicates.
client_indices = {str(client): [] for client in range(1, num_clients + 1)}


def add_index(client_num, index):
    """Assign ``index`` to the first client with spare capacity.

    Clients are tried starting at ``client_num`` (1-based), wrapping around
    past the last client, so a sample is dropped only when every client
    already holds ``num_data_per_client`` indices.
    """
    for offset in range(num_clients):
        candidate = str((client_num - 1 + offset) % num_clients + 1)
        if len(client_indices[candidate]) < num_data_per_client:
            client_indices[candidate].append(index)
            return


for i in range(len(zenseactmetadata)):
    add_index(weather_to_client.get(zenseactmetadata[i], default_client), i)


In [38]:
# One DataLoader per client, each over that client's subset of the dataset.
local_dataloaders = [
    DataLoader(Subset(zenseactssldataset, indices), batch_size=batch_size)
    for indices in client_indices.values()
]
local_dataloaders

[<torch.utils.data.dataloader.DataLoader at 0x7fc50183d090>,
 <torch.utils.data.dataloader.DataLoader at 0x7fc50212b0a0>,
 <torch.utils.data.dataloader.DataLoader at 0x7fc5021299c0>,
 <torch.utils.data.dataloader.DataLoader at 0x7fc50212b700>,
 <torch.utils.data.dataloader.DataLoader at 0x7fc502129e70>,
 <torch.utils.data.dataloader.DataLoader at 0x7fc50212b190>,
 <torch.utils.data.dataloader.DataLoader at 0x7fc502138cd0>,
 <torch.utils.data.dataloader.DataLoader at 0x7fc502139ff0>,
 <torch.utils.data.dataloader.DataLoader at 0x7fc502139000>,
 <torch.utils.data.dataloader.DataLoader at 0x7fc502138430>]

In [39]:
zenseactmetadata = ZenseactMetadata(meta_data)
zenseactssldataset = ZenseactSSLDataset(data)
# The split below indexes images by metadata position; both must be aligned.
assert len(zenseactmetadata) == len(zenseactssldataset), "metadata and image data differ in length"

def load_data_noniid(zenseactssldataset, zenseactmetadata, num_clients, batch_size):
    """Split a dataset across clients by weather condition (non-IID).

    Each weather condition has a primary client (1-based):
    partly-cloudy-day -> 1, cloudy -> 2, clear-day -> 3, rain -> 4,
    clear-night -> 5, snow -> 6, partly-cloudy-night -> 7, fog -> 8,
    anything else -> 9. When there are fewer clients than that, the primary
    client is reduced modulo ``num_clients``.

    Each client holds at most ``len(zenseactssldataset) // num_clients``
    indices. If the target client is full, the sample goes to the next client
    with room, wrapping around after the last client. Samples are dropped
    only once every client is full.

    Args:
        zenseactssldataset: dataset whose items are split, indexed by position.
        zenseactmetadata: per-item weather strings, aligned with the dataset.
        num_clients: number of clients / DataLoaders to create.
        batch_size: batch size of every returned DataLoader.

    Returns:
        list of ``num_clients`` DataLoaders (client 1 first), each over a
        ``Subset`` of ``zenseactssldataset``.
    """
    weather_to_client = {
        "partly-cloudy-day": 1,
        "cloudy": 2,
        "clear-day": 3,
        "rain": 4,
        "clear-night": 5,
        "snow": 6,
        "partly-cloudy-night": 7,
        "fog": 8,
    }
    default_client = 9

    num_data_per_client = len(zenseactssldataset) // num_clients
    client_indices = [[] for _ in range(num_clients)]

    def add_index(client_num, index):
        # client_num is 1-based. Try it first, then the following clients
        # (wrapping around), and store the index in the first one with room.
        for offset in range(num_clients):
            bucket = client_indices[(client_num - 1 + offset) % num_clients]
            if len(bucket) < num_data_per_client:
                bucket.append(index)
                return

    for i in range(len(zenseactmetadata)):
        add_index(weather_to_client.get(zenseactmetadata[i], default_client), i)

    return [
        DataLoader(Subset(zenseactssldataset, indices), batch_size=batch_size)
        for indices in client_indices
    ]


In [40]:
# Reuse the split configuration (num_clients, batch_size) instead of repeating
# the literals, and keep the result rather than discarding it.
local_dataloaders = load_data_noniid(zenseactssldataset, zenseactmetadata, num_clients, batch_size)
local_dataloaders

[<torch.utils.data.dataloader.DataLoader at 0x7fc5d67a73d0>,
 <torch.utils.data.dataloader.DataLoader at 0x7fc5021707c0>,
 <torch.utils.data.dataloader.DataLoader at 0x7fc5021735e0>,
 <torch.utils.data.dataloader.DataLoader at 0x7fc502173730>,
 <torch.utils.data.dataloader.DataLoader at 0x7fc502173af0>,
 <torch.utils.data.dataloader.DataLoader at 0x7fc502173c70>,
 <torch.utils.data.dataloader.DataLoader at 0x7fc5021733d0>,
 <torch.utils.data.dataloader.DataLoader at 0x7fc502172f80>,
 <torch.utils.data.dataloader.DataLoader at 0x7fc502172050>,
 <torch.utils.data.dataloader.DataLoader at 0x7fc5021720b0>]