In [None]:
import os
import warnings
from collections import OrderedDict

import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm

from handover_grasping import HANet

warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
def train(net, trainloader, epochs, lr=1e-3):
    """Train the model on the training set.

    Runs ``epochs`` passes over ``trainloader``, updating ``net`` in place with
    Adam and ``BCEWithLogitsLoss``. A fresh optimizer is created on every call,
    so Adam's moment estimates do not carry over between calls.

    Args:
        net: model called as ``net(color, depth)``; its output is compared
            against the permuted label tensor.
        trainloader: iterable of dict batches with 'color', 'depth' and
            'label' tensors (e.g. a ``DataLoader``).
        epochs: number of full passes over ``trainloader``.
        lr: Adam learning rate (default 1e-3, the previously hard-coded value).
    """
    # Make sure layers such as dropout / batch-norm are in training mode; an
    # earlier evaluation pass may have switched the model to eval mode.
    net.train()
    # Send batches to wherever the model's parameters live instead of
    # assuming CUDA.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    for _ in range(epochs):
        for batch in trainloader:
            color = batch['color'].to(device)
            depth = batch['depth'].to(device)
            # permute moves axis 1 to the end so the label layout matches the
            # network output (presumably (B, C, H, W) -> (B, H, W, C); TODO confirm).
            label = batch['label'].permute(0, 2, 3, 1).to(device).float()

            optimizer.zero_grad()
            loss = criterion(net(color, depth), label)
            loss.backward()
            optimizer.step()


def test(net, testloader):
    """Validate the model on the test set."""
    criterion = torch.nn.BCEWithLogitsLoss()
    correct, total, loss = 0, 0, 0.0
    with torch.no_grad():
        for i_batch, sampled_batched in enumerate(testloader):
            color = sampled_batched['color'].cuda()
            depth = sampled_batched['depth'].cuda()
            labels = sampled_batched['label'].permute(0,2,3,1).cuda().float()
            
            outputs = net(color, depth)
            loss += criterion(outputs, labels).item()
            
    return loss / len(testloader.dataset)


def load_data(data_path='/home/arg/handover_grasping/data/HANet_training_datasets',
              train_batch_size=8, test_batch_size=1, num_workers=8):
    """Build the training and federated-test DataLoaders.

    Args:
        data_path: dataset root. The default is the previously hard-coded,
            machine-specific path; pass your own location instead of editing it.
        train_batch_size: batch size for the (shuffled) training loader.
        test_batch_size: batch size for the (unshuffled) test loader.
        num_workers: worker processes used by each DataLoader.

    Returns:
        ``(trainloader, testloader)``. The test set is built with ``mode='fl_test'``.
    """
    # NOTE(review): `handover_grasping_dataset` is not imported in the cells
    # shown (only HANet is imported from handover_grasping). Without the right
    # import this raises NameError. TODO confirm its module path and import it.
    dataset_train = handover_grasping_dataset(data_path, color_type='png')
    dataset_test = handover_grasping_dataset(data_path, color_type='png', mode='fl_test')

    trainloader = DataLoader(dataset_train, batch_size=train_batch_size,
                             shuffle=True, num_workers=num_workers)
    testloader = DataLoader(dataset_test, batch_size=test_batch_size,
                            shuffle=False, num_workers=num_workers)
    return trainloader, testloader

In [None]:
# Build the HANet model (constructor argument 4, as in the original setup),
# move it to the default CUDA device, and create the train/test loaders.
net = HANet(4).cuda()
trainloader, testloader = load_data()

In [None]:
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates the module-level ``net``.

    Uses the notebook globals ``net``, ``trainloader``, ``testloader``,
    ``train`` and ``test``.
    """

    def get_parameters(self, config=None):
        """Return the values of ``net.state_dict()`` as NumPy arrays, in state_dict order.

        ``config`` is optional and ignored. It lets the method also match Flower
        versions that call ``get_parameters(config)``.
        """
        return [val.cpu().numpy() for _, val in net.state_dict().items()]

    def set_parameters(self, parameters):
        """Load a list of NumPy arrays into ``net``, matched to state_dict keys by position."""
        params_dict = zip(net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        # strict=True: any missing or unexpected key raises instead of being skipped.
        net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Load the server's weights, train one local epoch, and return the updated weights.

        Returns:
            ``(parameters, num_train_examples, metrics)``.
        """
        self.set_parameters(parameters)
        train(net, trainloader, epochs=1)
        return self.get_parameters(), len(trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load the server's weights and report the local test loss.

        Returns:
            ``(loss, num_test_examples, metrics)``, the triple Flower's
            NumPyClient expects. Previously ``(0.0, 0.0)`` was returned and the
            computed loss was discarded.
        """
        self.set_parameters(parameters)
        loss = test(net, testloader)
        return float(loss), len(testloader.dataset), {}

In [None]:
# Address of the Flower server. Set the FL_SERVER_IP environment variable to
# override it without editing the notebook; otherwise replace the placeholder.
TARGET_SERVER_IP = os.environ.get("FL_SERVER_IP", "your_own_IP")

In [None]:
# Connect this notebook's FlowerClient to the server at TARGET_SERVER_IP, port 8080.
server_address = TARGET_SERVER_IP + ":8080"
fl.client.start_numpy_client(server_address, client=FlowerClient())