In [117]:
!python --version

Python 3.7.12


In [118]:
!nvidia-smi

Mon Jan 17 18:54:35 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.46       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P0    39W / 250W |      2MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [119]:
%%writefile training_utils.py
import pickle
import time
import warnings
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

from centered_clip import decentralized_centered_clip
from evaluator import evaluate_accuracy, verify_equal_parameters

# Training-time pipeline with augmentation.
# MNIST digits are 28x28, so the random crop must also be 28x28 (after 4px padding):
# a 32x32 crop would change the flattened feature size and break a network that
# hard-codes the 28x28 geometry.
transform_augment = transforms.Compose([
    transforms.RandomCrop(28, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Deterministic pipeline (no augmentation): tensor conversion + the same normalisation.
transform_deterministic = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])


def train_with_centerclip(config, device: torch.device, writer: Optional[SummaryWriter] = None, verbose: int = 0):
    """Train ``config.MODEL`` on MNIST, aggregating gradients with decentralized centered clipping.

    ``torch.distributed`` must already be initialised: the function queries the rank and
    world size and issues collectives (barriers, all-reduces and the clipping exchange).

    Args:
        config: settings object; the attributes read here are MODEL, GLOBAL_SEED, AUGMENT_DATA,
            BATCH_SIZE_PER_WORKER, NUM_WORKERS, BASE_LR, MOMENTUM, NESTEROV, WEIGHT_DECAY,
            COSINE_T_MAX_RATE, MAX_EPOCHS_PER_WORKER, BYZANTINE_PARTICIPANT, BYZANTINE_IDS,
            BENIGN_PARTICIPANT, CCLIP_TAU, CCLIP_MAX_ITERS, CCLIP_EPS, EVAL_EVERY,
            EVAL_BATCH_SIZE and EXP_NAME. Optional attributes: INITIAL_CHECKPOINT together
            with INITIAL_STEP (resume), CHECKPOINT_DUMP_STEPS and EARLY_STOP_STEPS.
        device: device the model and the mini-batches are moved to.
        writer: optional TensorBoard writer; when given, per-step scalars are logged to it.
        verbose: 0 is silent, 1 prints progress and evaluation lines, >= 2 additionally
            prints per-worker running loss/accuracy at every evaluation.

    Returns:
        The ``(model, optimizer)`` pair after training or early stopping.
    """
    rank, world_size = dist.get_rank(), dist.get_world_size()
    torch.manual_seed(config.GLOBAL_SEED)  # seed for init
    model = config.MODEL.to(device)

    torch.manual_seed(config.GLOBAL_SEED * world_size + rank)  # seed for minibatches
    if verbose:
        print(f'==> [worker {rank}] Preparing data..')

    transform_train = transform_augment if config.AUGMENT_DATA else transform_deterministic
    # download=False: the dataset is expected to be present under ./data already.
    trainset = torchvision.datasets.MNIST(
        root='./data', train=True, download=False, transform=transform_train)
    testset = torchvision.datasets.MNIST(
        root='./data', train=False, download=False, transform=transform_deterministic)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=config.BATCH_SIZE_PER_WORKER, shuffle=True, num_workers=0)

    # optimizers and LR
    steps_per_global_epoch = int(len(trainset) / config.BATCH_SIZE_PER_WORKER / config.NUM_WORKERS)
    steps_per_local_epoch = int(len(trainset) / config.BATCH_SIZE_PER_WORKER)
    optimizer = torch.optim.SGD(model.parameters(), lr=config.BASE_LR, momentum=config.MOMENTUM,
                                nesterov=config.NESTEROV, weight_decay=config.WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config.COSINE_T_MAX_RATE * config.MAX_EPOCHS_PER_WORKER * steps_per_local_epoch)

    # Workers listed in BYZANTINE_IDS compute (possibly corrupted) gradients via the attack class.
    participant_cls = config.BYZANTINE_PARTICIPANT if rank in config.BYZANTINE_IDS else config.BENIGN_PARTICIPANT
    participant = participant_cls(model, optimizer, scheduler)

    loss_history, acc_history = [], []
    total_steps = 0

    if hasattr(config, 'INITIAL_CHECKPOINT'):
        print(f'[*] Resuming from step {config.INITIAL_STEP}, state `{config.INITIAL_CHECKPOINT}`...')

        total_steps = config.INITIAL_STEP

        # NOTE(security): pickle.load can execute arbitrary code; only resume from trusted files.
        with open(config.INITIAL_CHECKPOINT, 'rb') as f:
            state = pickle.load(f)
        model.load_state_dict(state['model'])
        optimizer.load_state_dict(state['opt'])

        # CUDA housekeeping is only valid (and only needed) when a GPU is present.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

        # Fast-forward the LR schedule to the resumed step; stepping the scheduler without
        # optimizer steps triggers warnings, which are irrelevant here.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            for _ in range(total_steps):
                scheduler.step()

        print('[+] State loaded')

    # Loop-invariant optional settings.
    checkpoint_dump_steps = getattr(config, 'CHECKPOINT_DUMP_STEPS', [])
    early_stop_steps = getattr(config, 'EARLY_STOP_STEPS', float('inf'))

    dist.barrier()
    start_time = time.time()

    for epoch_i in range(config.MAX_EPOCHS_PER_WORKER):
        if verbose:
            print(f'==> [worker {rank}] Began epoch {epoch_i}..')
        train_loss, train_acc = 0, 0

        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)
            model.train(True)  # evaluation below switches the model to eval mode
            optimizer.zero_grad()

            outputs = model(inputs)
            participant.compute_grads(inputs, outputs, targets)

            # Honest (uncorrupted) training metrics for logging.
            with torch.no_grad():
                loss = F.cross_entropy(outputs, targets)
                acc = torch.mean((torch.argmax(outputs, dim=-1) == targets).to(torch.float))
                del outputs

            # Replace the local gradients in place with the robustly aggregated ones.
            with torch.no_grad():
                grads = [param.grad for param in model.parameters()]
                clipped_grads, clip_stats = decentralized_centered_clip(
                    grads, tau=config.CCLIP_TAU, n_iters=config.CCLIP_MAX_ITERS, eps=config.CCLIP_EPS)
                for grad, clipped in zip(grads, clipped_grads):
                    grad[...] = clipped

            optimizer.step()
            scheduler.step()

            # Worst-case clipping statistics across workers.
            max_metrics_tuple = torch.tensor([clip_stats.n_clipped, clip_stats.step_norm,
                                              clip_stats.num_steps, clip_stats.std])
            dist.all_reduce(max_metrics_tuple, op=dist.ReduceOp.MAX)

            # Mean of all metrics across workers (sum-reduce, then divide by world size).
            metrics_tuple = torch.tensor([loss, acc, clip_stats.n_clipped, clip_stats.step_norm,
                                          clip_stats.num_steps, clip_stats.std])
            dist.all_reduce(metrics_tuple, op=dist.ReduceOp.SUM)
            metrics_tuple /= world_size
            loss, acc, clip_stats.n_clipped, clip_stats.step_norm, clip_stats.num_steps, clip_stats.std = list(
                metrics_tuple)

            loss_history.append(loss.item())
            acc_history.append(acc.item())

            if writer:
                writer.add_scalar('train/loss', loss.item(), global_step=total_steps)
                writer.add_scalar('train/acc', acc.item(), global_step=total_steps)
                writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], global_step=total_steps)
                writer.add_scalar('train/global_epoch', total_steps / steps_per_global_epoch, global_step=total_steps)
                writer.add_scalar('train/local_epoch',
                                  epoch_i + batch_idx * config.BATCH_SIZE_PER_WORKER / len(trainset),
                                  global_step=total_steps)

                writer.add_scalar('util/n_clipped', clip_stats.n_clipped, global_step=total_steps)
                writer.add_scalar('util/final_step_norm', clip_stats.step_norm, global_step=total_steps)
                writer.add_scalar('util/std', clip_stats.std, global_step=total_steps)
                writer.add_scalar('util/num_steps', clip_stats.num_steps, global_step=total_steps)
                # Same value as 'util/std'; both tags are kept so existing dashboards keep working.
                writer.add_scalar('util/mean_vector_std', clip_stats.std, global_step=total_steps)

                writer.add_scalar('util/max_n_clipped', max_metrics_tuple[0], global_step=total_steps)
                writer.add_scalar('util/max_step_norm', max_metrics_tuple[1], global_step=total_steps)
                writer.add_scalar('util/max_num_steps', max_metrics_tuple[2], global_step=total_steps)
                writer.add_scalar('util/max_vector_std', max_metrics_tuple[3], global_step=total_steps)

            if rank == 0 and total_steps in checkpoint_dump_steps:
                filename = f'state_step_{total_steps}_exp_{config.EXP_NAME}.pickle'
                state = {
                    'model': model.state_dict(),
                    'opt': optimizer.state_dict(),
                }
                with open(filename, 'wb') as f:
                    pickle.dump(state, f)
                print(f'[+] Saved checkpoint to {filename}')

            if total_steps % config.EVAL_EVERY == 0:
                # Collective call: every worker must take part, not only rank 0.
                checksum_match = verify_equal_parameters(model)
                if rank == 0:
                    val_acc = evaluate_accuracy(model, testset, config.EVAL_BATCH_SIZE)
                    if writer:
                        writer.add_scalar('test/accuracy', val_acc, global_step=total_steps)
                    if verbose:
                        print(
                            end=f'step {str(total_steps).rjust(5, "0")}\t| val accuracy = {val_acc:.5f}\t| training for {time.time() - start_time:.5f}s.\t| checksum ok = {checksum_match}\n')

                # Keep the other workers waiting while rank 0 evaluates.
                dist.barrier()
                if verbose >= 2:
                    print(
                        end=f"worker {str(rank).rjust(2, '0')}, step {total_steps}\t| loss: {np.mean(loss_history[-config.EVAL_EVERY:]):.5f},"
                            f" acc: {np.mean(acc_history[-config.EVAL_EVERY:]):.5f}\n")
                    if rank == 0:
                        print()

            total_steps += 1
            if total_steps >= early_stop_steps:
                if verbose:
                    print(f"worker {str(rank).rjust(2, '0')} stopping at {total_steps}")
                return model, optimizer
    return model, optimizer
        

Overwriting training_utils.py


In [120]:
%%writefile evaluator.py
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter


@torch.no_grad()
def evaluate_accuracy(model: nn.Module, dataset: torch.utils.data.Dataset, batch_size: int, num_workers: int = 0):
    """Return the fraction of samples in `dataset` whose arg-max prediction equals the target.

    The model is switched to evaluation mode and left in that mode; inputs are moved to the
    device of the model's first parameter. The result is a 0-dim tensor.
    """
    model.eval()
    device = next(model.parameters()).device
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    n_correct, n_seen = 0, 0
    for batch_inputs, batch_targets in loader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)
        predictions = model(batch_inputs).argmax(-1)
        n_correct += (predictions == batch_targets).to(torch.float32).sum()
        n_seen += len(batch_inputs)
    return n_correct / n_seen


@torch.no_grad()
def verify_equal_parameters(model):
    """Check that every worker holds (approximately) the same model parameters.

    Each worker reduces its parameters to one scalar checksum (the sum of all parameter
    sums), the checksums are all-gathered, and neighbouring entries are compared with
    ``torch.allclose``. This is a collective call: every rank of the default process
    group must invoke it.

    The gather buffers are allocated with ``zeros_like`` so that they match the checksum's
    dtype and so that the check does not draw from the global random number generator.

    Returns:
        True when all gathered checksums are close to each other, False otherwise.
    """
    local_checksum = sum(p.sum() for p in model.parameters()).cpu()
    gathered = [torch.zeros_like(local_checksum) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local_checksum)
    checksums = torch.stack(gathered)
    # Pairwise comparison of consecutive ranks; trivially True for a single worker.
    return torch.allclose(checksums[:-1], checksums[1:])


Overwriting evaluator.py


In [121]:
%%writefile centered_clip.py
import dataclasses
from typing import Sequence, Tuple

import torch
import torch.distributed as dist


def split_into_parts(tensors: Sequence[torch.Tensor], num_parts: int) -> Tuple[torch.Tensor, ...]:
    """Flatten and concatenate `tensors`, then split the result into `num_parts` equal 1-D chunks.

    When the total number of elements is not divisible by `num_parts`, the flat tensor is
    zero-padded at the end so that every chunk has the same length. The padding uses the
    dtype and device of the first tensor, so the chunks keep the inputs' dtype.

    Returns:
        A tuple of `num_parts` 1-D tensors that are views of one contiguous flat tensor.
    """
    flat_parts = [t.flatten() for t in tensors]
    total_size = sum(part.numel() for part in flat_parts)
    remainder = total_size % num_parts
    if remainder:
        flat_parts.append(torch.zeros(
            num_parts - remainder, dtype=flat_parts[0].dtype, device=flat_parts[0].device))
    flat_tensor = torch.cat(flat_parts)
    return torch.split(flat_tensor, len(flat_tensor) // num_parts, dim=0)


def restore_from_parts(chunks: Sequence[torch.Tensor], shapes: Sequence[torch.Size]) -> Tuple[torch.Tensor, ...]:
    """Inverse of split_into_parts: rebuild tensors of the given `shapes` from flat `chunks`.

    The chunks are concatenated, any trailing padding beyond the total element count of
    `shapes` is dropped, and the remainder is cut and reshaped shape by shape.
    """
    element_counts = tuple(torch.Size.numel(shape) for shape in shapes)
    joined = torch.cat(tuple(chunks))
    payload = joined[:sum(element_counts)]
    flat_pieces = torch.split_with_sizes(payload, element_counts)
    return tuple(piece.reshape(shape) for piece, shape in zip(flat_pieces, shapes))


@dataclasses.dataclass
class CenteredClipResult:
    """Mutable record holding a centered-clip aggregate and its diagnostics."""
    # aggregated vector
    result: torch.Tensor
    # number of inputs whose clipping coefficient was below 1 in the last iteration
    n_clipped: torch.Tensor
    # norm of the last update step
    step_norm: torch.Tensor
    # iteration counter reported by the aggregation loop
    num_steps: torch.Tensor
    # spread of the input vectors around their mean
    std: torch.Tensor


def centered_clip(input_tensors: torch.Tensor, weights: torch.Tensor,
                  tau: float, n_iters: int = 20, eps: float = 1e-6) -> CenteredClipResult:
    """Robustly aggregate the rows of `input_tensors` by iterative centered clipping.

    Starting from the coordinate-wise median, each iteration moves the estimate by the
    weighted mean of the differences to all inputs, where every difference is scaled by
    ``min(1, tau / norm)`` (i.e. clipped to norm `tau`). Iteration stops after `n_iters`
    rounds or as soon as the step norm is at most `eps`.

    Args:
        input_tensors: stacked inputs; dimension 0 indexes the inputs, the remaining
            dimensions are flattened for the computation.
        weights: one non-negative weight per input.
        tau: clipping radius.
        n_iters: maximum number of iterations; 0 returns the median unchanged.
        eps: convergence threshold on the step norm.

    Returns:
        CenteredClipResult whose `result` has the shape of a single input,
        `n_clipped` / `step_norm` describe the last executed iteration (zeros when no
        iteration ran), `num_steps` is the zero-based index of the last executed
        iteration, and `std` is the root-mean-square distance of the inputs to their mean.
    """
    result_shape = input_tensors.shape[1:]
    input_tensors = input_tensors.flatten(start_dim=1)

    result = input_tensors.median(dim=0).values
    one = torch.tensor(1.0, device=result.device)

    # Defaults used when the loop body never runs (n_iters == 0).
    n_clipped = torch.tensor(0, device=result.device)
    step_norm = torch.tensor(0.0, device=result.device)
    last_iter = 0

    for last_iter in range(n_iters):
        diff = input_tensors - result
        coeffs = tau / diff.norm(dim=1)  # inf for inputs equal to the estimate; capped below
        n_clipped = (coeffs < 1.0).sum()
        coeffs = weights * torch.min(one, coeffs)
        step = (diff * coeffs[:, None]).sum(dim=0) / weights.sum()
        result += step
        step_norm = step.norm()
        if step_norm <= eps:
            break

    vector_std = torch.mean((input_tensors - input_tensors.mean(dim=0)).norm(dim=1) ** 2) ** 0.5
    return CenteredClipResult(result=result.reshape(result_shape), n_clipped=n_clipped, step_norm=step_norm,
                              num_steps=last_iter, std=vector_std)


def decentralized_centered_clip(local_tensors, **kwargs):
    """Aggregate `local_tensors` across all workers with centered clipping, one slice per worker.

    Every worker flattens its tensors into `world_size` equal parts and scatters them so
    that worker j receives part j from every peer. Each worker then runs centered_clip on
    the parts it received (uniform weights; `kwargs` are forwarded), and the clipped slices
    are all-gathered and reshaped back to the shapes of `local_tensors`.

    Returns:
        (tuple of aggregated tensors shaped like `local_tensors`,
         CenteredClipResult of this worker's slice).
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    tensor_parts = list(split_into_parts(local_tensors, num_parts=world_size))
    device = tensor_parts[0].device
    part_length = len(tensor_parts[rank])
    gathered_from_peers = torch.empty(world_size, part_length, device=device)

    # Launch one asynchronous scatter per source rank, then wait for all of them.
    pending = [
        dist.scatter(gathered_from_peers[src], tensor_parts if src == rank else None, src=src, async_op=True)
        for src in range(world_size)
    ]
    for work in pending:
        work.wait()

    uniform_weights = torch.ones(world_size, device=device)
    clipped = centered_clip(gathered_from_peers, weights=uniform_weights, **kwargs)

    dist.barrier()
    # Reuse the local part buffers as the destination of the gathered, clipped slices.
    dist.all_gather(tensor_parts, clipped.result)
    original_shapes = [t.shape for t in local_tensors]
    return restore_from_parts(tensor_parts, original_shapes), clipped


Overwriting centered_clip.py


# Attacks

In [122]:
%%writefile worker_types.py
import random

import torch
import torch.nn.functional as F


class NormalParticipant:
    """Honest worker: back-propagates the plain cross-entropy loss."""

    def __init__(self, model, optimizer, scheduler):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler

    def compute_grads(self, inputs, outputs, targets):
        """Populate parameter gradients from the cross-entropy of `outputs` against `targets`."""
        F.cross_entropy(outputs, targets).backward()


class SignFlipper:
    """Byzantine worker that sends the honest gradient multiplied by ``-flip_scale``.

    Args:
        model, optimizer, scheduler: training objects; only `model` is used by the attack.
        ban_prob: probability, evaluated after each attack, that the worker becomes
            "banned" and stops attacking for good.
        attack_start: attacks happen only on steps strictly greater than this value.
        direction_seed: unused here; accepted so all attack classes share one signature.
        attack_every: attack only on steps divisible by this value (must be non-zero).
        flip_scale: magnitude of the negative gradient multiplier.
    """

    def __init__(self, model, optimizer, scheduler, ban_prob: float, attack_start: int,
                 direction_seed: int = 0, attack_every: int = 1, flip_scale: float = 1000):
        self.model, self.optimizer, self.scheduler = model, optimizer, scheduler
        self.ban_prob, self.attack_start = ban_prob, attack_start
        self.num_steps, self.banned = 0, False
        self.attack_every = attack_every
        self.flip_scale = flip_scale

    def __repr__(self):
        return f"{self.__class__.__name__}({self.ban_prob}, {self.attack_start}, {self.flip_scale})"

    def compute_grads(self, inputs, outputs, targets):
        """Back-propagate the honest loss, then (when attacking) flip and scale every gradient."""
        loss = F.cross_entropy(outputs, targets)
        loss.backward()

        if self.num_steps > self.attack_start and self.num_steps % self.attack_every == 0 and not self.banned:
            print(end=f"ATTACK@{self.num_steps}\n")
            with torch.no_grad():
                for param in self.model.parameters():
                    param.grad *= -self.flip_scale

            if random.random() < self.ban_prob:
                print(f"BANNED@{self.num_steps}\n")
                self.banned = True

        self.num_steps += 1


class LabelFlipper:
    """Byzantine worker that trains against mirrored labels (``9 - target``) when attacking.

    Args:
        model, optimizer, scheduler: training objects, stored on the instance.
        ban_prob: probability, evaluated after each attack, that the worker becomes
            "banned" and stops attacking for good.
        attack_start: attacks happen only on steps strictly greater than this value.
        direction_seed: unused here; accepted so all attack classes share one signature.
        attack_every: attack only on steps divisible by this value (must be non-zero).
    """

    def __init__(self, model, optimizer, scheduler, ban_prob: float, attack_start: int,
                 direction_seed: int = 0, attack_every: int = 1):
        self.model, self.optimizer, self.scheduler = model, optimizer, scheduler
        self.ban_prob, self.attack_start = ban_prob, attack_start
        self.num_steps, self.banned = 0, False
        self.attack_every = attack_every

    def __repr__(self):
        # Only attributes that exist on this class are shown.
        return f"{self.__class__.__name__}({self.ban_prob}, {self.attack_start}, {self.attack_every})"

    def compute_grads(self, inputs, outputs, targets):
        """Back-propagate cross-entropy against flipped labels when attacking, true labels otherwise."""
        attacking = (self.num_steps > self.attack_start
                     and self.num_steps % self.attack_every == 0
                     and not self.banned)
        if attacking:
            print(end=f"ATTACK@{self.num_steps}\n")

        # Labels 0..9 are mirrored to 9..0 during an attack.
        loss = F.cross_entropy(outputs, (9 - targets) if attacking else targets)
        loss.backward()

        if attacking and random.random() < self.ban_prob:
            print(f"BANNED@{self.num_steps}\n")
            self.banned = True

        self.num_steps += 1


class ConstantDirection:
    """Byzantine worker that replaces each gradient with a fixed pseudo-random direction.

    When attacking, every parameter gradient is overwritten by a Gaussian vector drawn
    from a generator seeded with `direction_seed` (hence the same direction on every
    attack) and rescaled to 1000 times the norm of the honest gradient.

    Args:
        model, optimizer, scheduler: training objects; only `model` is used by the attack.
        ban_prob: probability, evaluated after each attack, that the worker becomes
            "banned" and stops attacking for good.
        attack_start: attacks happen only on steps strictly greater than this value.
        direction_seed: seed that fixes the attack direction.
        attack_every: attack only on steps divisible by this value (must be non-zero).
    """

    def __init__(self, model, optimizer, scheduler, ban_prob: float, attack_start: int,
                 direction_seed: int = 0, attack_every: int = 1):
        self.model, self.optimizer, self.scheduler = model, optimizer, scheduler
        self.ban_prob, self.attack_start = ban_prob, attack_start
        self.num_steps, self.banned = 0, False
        self.direction_seed = direction_seed
        self.attack_every = attack_every

    def __repr__(self):
        # Only attributes that exist on this class are shown.
        return f"{self.__class__.__name__}({self.ban_prob}, {self.attack_start}, {self.attack_every})"

    def compute_grads(self, inputs, outputs, targets):
        """Back-propagate the honest loss, then (when attacking) overwrite the gradients."""
        loss = F.cross_entropy(outputs, targets)
        loss.backward()

        if self.num_steps > self.attack_start and self.num_steps % self.attack_every == 0 and not self.banned:
            print(end=f"ATTACK@{self.num_steps}\n")
            grad_devices = {param.grad.device for param in self.model.parameters()}
            # fork_rng always saves the CPU generator; `devices` may only name CUDA devices,
            # so CPU entries are filtered out to keep this working on CPU-only hosts.
            cuda_devices = [dev for dev in grad_devices if dev.type == 'cuda']
            with torch.no_grad(), torch.random.fork_rng(devices=cuda_devices):
                torch.manual_seed(self.direction_seed)
                for param in self.model.parameters():
                    rand_buf = torch.randn_like(param.grad)
                    param.grad[...] = rand_buf * (1 / rand_buf.norm() * param.grad.norm() * 1000)

            if random.random() < self.ban_prob:
                print(f"BANNED@{self.num_steps}\n")
                self.banned = True

        self.num_steps += 1


class LabelShuffler:
    """Byzantine worker that trains against randomly permuted targets when attacking.

    Args:
        model, optimizer, scheduler: training objects, stored on the instance.
        ban_prob: probability, evaluated after each attack, that the worker becomes
            "banned" and stops attacking for good.
        attack_start: attacks happen only on steps strictly greater than this value.
        direction_seed: unused here; accepted so all attack classes share one signature.
        attack_every: attack only on steps divisible by this value (must be non-zero).
    """

    def __init__(self, model, optimizer, scheduler, ban_prob: float, attack_start: int,
                 direction_seed: int = 0, attack_every: int = 1):
        self.model, self.optimizer, self.scheduler = model, optimizer, scheduler
        self.ban_prob, self.attack_start = ban_prob, attack_start
        self.num_steps, self.banned = 0, False
        self.attack_every = attack_every

    def __repr__(self):
        # Only attributes that exist on this class are shown.
        return f"{self.__class__.__name__}({self.ban_prob}, {self.attack_start}, {self.attack_every})"

    def compute_grads(self, inputs, outputs, targets):
        """Back-propagate cross-entropy against shuffled targets when attacking, true targets otherwise."""
        if self.num_steps > self.attack_start and self.num_steps % self.attack_every == 0 and not self.banned:
            print(end=f"ATTACK@{self.num_steps}\n")

            # Permute the targets within the batch so they no longer match their inputs.
            shuffler = torch.randperm(targets.shape[0])
            shuffled_targets = targets[shuffler].view(targets.size())
            loss = F.cross_entropy(outputs, shuffled_targets)
            loss.backward()

            if random.random() < self.ban_prob:
                print(f"BANNED@{self.num_steps}\n")
                self.banned = True
        else:
            loss = F.cross_entropy(outputs, targets)
            loss.backward()

        self.num_steps += 1


class GaussianNoiseAdder:
    """Byzantine worker that adds element-wise Gaussian noise to every other parameter gradient.

    Args:
        model, optimizer, scheduler: training objects; only `model` is used by the attack.
        ban_prob: probability, evaluated after each attack, that the worker becomes
            "banned" and stops attacking for good.
        attack_start: attacks happen only on steps strictly greater than this value.
        direction_seed: unused here; accepted so all attack classes share one signature.
        attack_every: attack only on steps divisible by this value (must be non-zero).
        noise_std: standard deviation of the added noise.
    """

    def __init__(self, model, optimizer, scheduler, ban_prob: float, attack_start: int,
                 direction_seed: int = 0, attack_every: int = 1, noise_std: float = 100.0):
        self.model, self.optimizer, self.scheduler = model, optimizer, scheduler
        self.ban_prob, self.attack_start = ban_prob, attack_start
        self.num_steps, self.banned = 0, False
        self.attack_every = attack_every
        self.noise_std = noise_std

    def __repr__(self):
        # Only attributes that exist on this class are shown.
        return f"{self.__class__.__name__}({self.ban_prob}, {self.attack_start}, {self.attack_every})"

    def compute_grads(self, inputs, outputs, targets):
        """Back-propagate the honest loss, then (when attacking) perturb the gradients."""
        loss = F.cross_entropy(outputs, targets)
        loss.backward()

        if self.num_steps > self.attack_start and self.num_steps % self.attack_every == 0 and not self.banned:
            print(end=f"ATTACK@{self.num_steps}\n")
            with torch.no_grad():
                for param_index, param in enumerate(self.model.parameters()):
                    # Perturb only every other parameter.
                    if param_index % 2 == 0:
                        # Independent N(0, noise_std^2) sample per gradient element;
                        # random.normalvariate would add one shared scalar offset instead.
                        param.grad.add_(torch.randn_like(param.grad), alpha=self.noise_std)

            if random.random() < self.ban_prob:
                print(f"BANNED@{self.num_steps}\n")
                self.banned = True

        self.num_steps += 1


Overwriting worker_types.py


In [123]:
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [124]:
%%writefile model.py

import torch.nn as nn


class Net(nn.Module):
    """Small CNN: two 5x5 conv layers (with 2x2 max-pooling) followed by two linear layers.

    The flattening step assumes 320 features per sample after the second pooling stage;
    the forward pass returns log-probabilities over 10 classes.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        relu = nn.functional.relu
        max_pool2d = nn.functional.max_pool2d

        # Convolutional feature extractor.
        features = relu(max_pool2d(self.conv1(x), 2))
        features = self.conv2_drop(self.conv2(features))
        features = relu(max_pool2d(features, 2))

        # Classifier head on the flattened features.
        flat = features.view(-1, 320)
        hidden = relu(self.fc1(flat))
        hidden = nn.functional.dropout(hidden, training=self.training)
        logits = self.fc2(hidden)
        return nn.functional.log_softmax(logits, dim=1)


Overwriting model.py


In [125]:
%%writefile config.py
import random
from functools import partial

import model
from worker_types import NormalParticipant


class Config:
    """Experiment settings: class attributes are defaults, the constructor sets per-run values.

    Everything derived from a per-run value (Byzantine participant factory, Byzantine ids,
    early-stop step and experiment name) is recomputed in ``__init__`` so that it reflects
    the arguments instead of the class-level defaults.
    """

    # NOTE(review): the network is instantiated once, when this module is imported, and
    # the same object is shared by every Config instance created in the process.
    MODEL = model.Net()
    AUGMENT_DATA = False

    GLOBAL_SEED = 0
    NUM_WORKERS = 6
    MAX_EPOCHS_PER_WORKER = 4
    BATCH_SIZE_PER_WORKER = 256
    EVAL_BATCH_SIZE = 32
    EVAL_EVERY = 50

    # Centered-clip aggregation parameters.
    CCLIP_TAU = 1.0
    CCLIP_MAX_ITERS = 500
    CCLIP_EPS = 1e-6

    # Optimizer / LR schedule parameters.
    BASE_LR = 0.05
    MOMENTUM = 0.9
    NESTEROV = True
    WEIGHT_DECAY = 5e-4
    COSINE_T_MAX_RATE = 1.0

    BENIGN_PARTICIPANT = NormalParticipant
    ATTACK = NormalParticipant
    NAME = 'baseline'
    NUM_BYZANTINES = 0
    ATTACK_EVERY = 0
    ATTACK_START = 100

    NUM_BYZANTINES = 0 if NAME == 'baseline' else NUM_BYZANTINES
    BYZANTINE_PARTICIPANT = partial(
        ATTACK, ban_prob=1. / NUM_WORKERS * ((NUM_WORKERS - NUM_BYZANTINES) / NUM_WORKERS),
        attack_start=ATTACK_START, direction_seed=GLOBAL_SEED, attack_every=ATTACK_EVERY)
    # Rank 0 is never sampled as Byzantine (the range starts at 1).
    BYZANTINE_IDS = random.Random(GLOBAL_SEED).sample(range(1, NUM_WORKERS), NUM_BYZANTINES)

    EARLY_STOP_STEPS = 5_100 if ATTACK_START == 100 else 14_100
    EXP_NAME = f"mnist_decentclip_tau{CCLIP_TAU}_max_iters{CCLIP_MAX_ITERS}_seed{GLOBAL_SEED}"

    def __init__(self, num_byzantines, attack_start, attack_every, global_seed, tau, attack, name):
        """Override the defaults for one experiment.

        Args:
            num_byzantines: number of attacking workers (forced to 0 when name is 'baseline').
            attack_start: step after which attackers may start attacking.
            attack_every: attackers act on steps divisible by this value.
            global_seed: seed for model init, data order, attack direction and id sampling.
            tau: centered-clip radius.
            attack: participant class used for Byzantine workers.
            name: experiment label; 'baseline' disables all attackers.
        """
        self.NAME = name
        self.ATTACK = attack
        self.GLOBAL_SEED = global_seed
        self.CCLIP_TAU = tau
        self.ATTACK_START = attack_start
        self.ATTACK_EVERY = attack_every
        self.NUM_BYZANTINES = 0 if name == 'baseline' else num_byzantines

        honest_fraction = (self.NUM_WORKERS - self.NUM_BYZANTINES) / self.NUM_WORKERS
        self.BYZANTINE_PARTICIPANT = partial(
            self.ATTACK, ban_prob=1. / self.NUM_WORKERS * honest_fraction,
            attack_start=self.ATTACK_START, direction_seed=self.GLOBAL_SEED,
            attack_every=self.ATTACK_EVERY)
        self.BYZANTINE_IDS = random.Random(self.GLOBAL_SEED).sample(range(1, self.NUM_WORKERS), self.NUM_BYZANTINES)

        # Derived values must follow the per-run settings, not the class defaults.
        self.EARLY_STOP_STEPS = 5_100 if self.ATTACK_START == 100 else 14_100
        self.EXP_NAME = (f"mnist_decentclip_tau{self.CCLIP_TAU}_max_iters{self.CCLIP_MAX_ITERS}"
                         f"_seed{self.GLOBAL_SEED}")


Overwriting config.py


In [126]:
%%writefile experiments.py 
import os
import socket
import sys
import time
from contextlib import closing
from contextlib import contextmanager

import torch
import torch.distributed as dist
import torchvision
from torch.multiprocessing import Process
from torch.utils.tensorboard import SummaryWriter

import worker_types
from config import Config
from training_utils import train_with_centerclip


def find_free_port():
    """Bind a throw-away TCP socket to port 0 and return the OS-assigned port as a string."""
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with closing(probe):
        probe.bind(('', 0))
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        port_number = probe.getsockname()[1]
        return str(port_number)


@contextmanager
def suppress_stdout():
    """Context manager that redirects ``sys.stdout`` to the null device for the duration of the block.

    The previous ``sys.stdout`` object is restored on exit, also when the block raises.
    """
    with open(os.devnull, "w") as sink:
        saved_stdout, sys.stdout = sys.stdout, sink
        try:
            yield
        finally:
            sys.stdout = saved_stdout


def run_worker(rank, config, LOGS_FOLDER):
    """Entry point of one worker process: join the process group and run training.

    Args:
        rank: rank of this worker in the process group.
        config: experiment settings; NUM_WORKERS is used as the world size and EXP_NAME
            in the log directory name.
        LOGS_FOLDER: directory (relative to the working directory) for TensorBoard logs.

    Only rank 0 writes TensorBoard logs. The writer is closed and the process group is
    destroyed when training ends, also when it ends with an exception.
    """
    print("worker created", rank)
    backend = 'gloo'
    print("value of rank", rank)
    # 'env://' rendezvous: MASTER_ADDR / MASTER_PORT are read from the environment.
    dist.init_process_group(backend, init_method='env://', rank=rank, world_size=config.NUM_WORKERS)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    writer = None
    try:
        if rank == 0:
            writer = SummaryWriter('./{}/{}_rank{}_{}.{:0>2d}.{:0>2d}_{:0>2d}:{:0>2d}:{:0>2d}'.format(
                LOGS_FOLDER, config.EXP_NAME, rank, *time.gmtime()[:6]))
            writer.add_text('config', '\n'.join(f'{k}: {v}' for k, v in config.__dict__.items()
                                                if not k.startswith('_')))

        train_with_centerclip(config, device, writer, verbose=1)
        print(f'Complete worker {rank}\'s run')
    finally:
        # Flush pending events to disk and release the communication resources.
        if writer is not None:
            writer.close()
        dist.destroy_process_group()


if __name__ == "__main__":
    # One intra-op thread per process: several workers share the same machine.
    torch.set_num_threads(1)

    os.environ['MASTER_ADDR'] = '127.0.0.1'

    # Download MNIST once up front; the workers open it with download=False.
    with suppress_stdout():
        torchvision.datasets.MNIST(root='./data', train=True, download=True)
        torchvision.datasets.MNIST(root='./data', train=False, download=True)

    attacks = [
        ('baseline', worker_types.NormalParticipant),
        ('gaussiannoise', worker_types.GaussianNoiseAdder),
        ('shuffle', worker_types.LabelShuffler),
        ('signflip', worker_types.SignFlipper),
        ('labelflip', worker_types.LabelFlipper),
        ('constantdirection', worker_types.ConstantDirection),
    ]

    for seed in range(3):
        for num_byzantines, attack_every in [(2, 1), (1, 5)]:
            tau = 5.0
            attack_start = 100
            for name, attack in attacks:
                print(
                    f'Seed = {seed}; Num_byzantines = {num_byzantines}; Attack_every = {attack_every}; Tau = {tau}')
                print(f'ATTACK = {name}')

                LOGS_FOLDER = f'{name}_at_{attack_start}_attackers_{num_byzantines}_every_{attack_every}'
                print(name, attack, LOGS_FOLDER)
                config = Config(num_byzantines, attack_start, attack_every, seed, tau, attack, name)

                # Fresh rendezvous port per experiment, so a run never collides with sockets
                # left over from the previous process group. Must be set before the workers
                # are started, because they read it from the inherited environment.
                os.environ['MASTER_PORT'] = find_free_port()

                processes = []
                for rank in range(0, config.NUM_WORKERS):
                    p = Process(target=run_worker, args=(rank, config, LOGS_FOLDER), daemon=True)
                    p.start()
                    processes.append(p)

                for p in processes:
                    p.join()

                # Surface crashed workers instead of silently moving on to the next experiment.
                failed_ranks = [rank for rank, p in enumerate(processes) if p.exitcode != 0]
                if failed_ranks:
                    print(f'[!] Workers {failed_ranks} exited with a non-zero code in experiment {LOGS_FOLDER}')


Overwriting experiments.py


In [127]:
!python experiments.py

Seed = 0; Num_byzantines = 2; Attack_every = 1; Tau = 5.0
ATTACK = baseline
baseline <class 'worker_types.NormalParticipant'> baseline_at_100_attackers_2_every_1
worker created 0
value of rank 0
worker created 1
value of rank 1
worker created 2
value of rank 2
worker created 3
value of rank 3
worker created 4
value of rank 4
worker created 5
value of rank 5
==> [worker 2] Preparing data..
==> [worker 3] Preparing data..
==> [worker 4] Preparing data..
==> [worker 1] Preparing data..
==> [worker 5] Preparing data..
==> [worker 0] Preparing data..
==> [worker 3] Began epoch 0..
==> [worker 1] Began epoch 0..
==> [worker 5] Began epoch 0..
==> [worker 2] Began epoch 0..
==> [worker 4] Began epoch 0..
==> [worker 0] Began epoch 0..
step 00000	| val accuracy = 0.09420	| training for 1.50494s.	| checksum ok = True
step 00050	| val accuracy = 0.92690	| training for 8.09749s.	| checksum ok = True
step 00100	| val accuracy = 0.95820	| training for 14.63745s.	| checksum ok = True
step 00150	| va

In [128]:
%reload_ext tensorboard

In [None]:
tensorboard --logdir .

# Save and download the experiment results

In [None]:
!zip -r ./experiment_results.zip ./ 

In [None]:
from google.colab import files
files.download("experiment_results.zip")

# Visualisation of the model

In [115]:
!pip install torchviz



In [116]:
import torch
from torchviz import make_dot
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

import model 

# Same deterministic preprocessing as used for training without augmentation.
transform_deterministic = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

writer = SummaryWriter('runs/plotter')
net = model.Net()
net.train()

transform_train = transform_deterministic
trainset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform_deterministic)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=0)

# Take one batch to trace the network with; the built-in next() is the portable way
# to advance an iterator (the `.next()` method is not part of the iterator protocol).
dataiter = iter(trainloader)
images, labels = next(dataiter)

# Render the autograd graph of a forward pass to cnn_torchviz.png.
yhat = net(images)
make_dot(yhat, params=dict(list(net.named_parameters()))).render("cnn_torchviz", format="png")

# Export the network to ONNX using the sample batch as the tracing input.
input_names = ['input']
output_names = ['output']
torch.onnx.export(net, images, 'rcn.onnx', input_names=input_names, output_names=output_names)

Note: Open the exported `rcn.onnx` file in Netron (https://netron.app) to inspect the network graph.