In [2]:
import numpy as np
import os
import torch
import torch.distributed as dist
import torch.optim as optim
from torch.multiprocessing import Process
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from random import Random

# Number of worker processes (ranks) launched for distributed training.
WORLD_SIZE = 2
# Number of full passes over each rank's training partition.
NUM_EPOCHS = 50
# Training accuracy is printed/recorded every this many batches.
TRAINING_RECORD_INTERVAL = 25
# Global batch size summed over all ranks; each rank processes
# BATCH_SIZE / WORLD_SIZE samples per step.
BATCH_SIZE = 128

In [None]:
class Partion(object):
    """Read-only view of ``data`` restricted to the positions listed in ``index``.

    Item ``i`` of the view is ``data[index[i]]``; nothing is copied.
    """

    def __init__(self, data, index):
        # Keep references to the backing dataset and the selected positions.
        self.data, self.index = data, index

    def __len__(self):
        """Number of positions exposed by this view."""
        return len(self.index)

    def __getitem__(self, index):
        """Return the backing item that view position ``index`` maps to."""
        return self.data[self.index[index]]

# Training pipeline: random crop/flip augmentation, then tensor conversion and
# per-channel normalization.
transform_train = transforms.Compose([
    # 224 matches the evaluation pipeline below; the previous value of 244 made
    # the network see a different input resolution in training than in testing.
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Evaluation pipeline: a deterministic resize instead of a random crop, so the
# reported test accuracy does not change from run to run for the same weights.
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

class DataPartitioner(object):
    """Split a labelled dataset into partitions made of contiguous label ranges.

    The label space ``[0, num_classes)`` is cut into ``num_partitions`` equal
    ranges and every sample index is assigned to the range its label falls in.
    With the defaults (2 partitions, 10 classes) labels 0-4 form partition 0
    and labels 5-9 form partition 1. Each partition's index list is shuffled
    with a seeded generator, so the result is reproducible.

    Args:
        data: indexable dataset whose items are ``(sample, label)`` pairs.
        seed: seed for the generator that shuffles each partition.
        num_partitions: number of partitions to create (must be >= 1).
        num_classes: number of distinct labels (must be >= 1).

    Raises:
        ValueError: if ``num_partitions`` or ``num_classes`` is less than 1.
    """

    def __init__(self, data, seed = 8675309, num_partitions = 2, num_classes = 10):
        if num_partitions < 1 or num_classes < 1:
            raise ValueError("num_partitions and num_classes must both be >= 1")

        self.data = data
        self.partitions = [[] for _ in range(num_partitions)]

        last_partition = num_partitions - 1
        for idx, label in enumerate(self._labels(data)):
            bucket = int(label) * num_partitions // num_classes
            # Out-of-range labels are clamped into the first / last partition.
            bucket = min(max(bucket, 0), last_partition)
            self.partitions[bucket].append(idx)

        # One generator shuffles the partitions in order, so the layout for a
        # given seed is stable.
        rng = Random()
        rng.seed(seed)
        for indexes in self.partitions:
            rng.shuffle(indexes)

    @staticmethod
    def _labels(data):
        """Yield the label of every item of ``data`` in index order.

        When the dataset exposes a ``targets`` sequence of matching length and
        applies no ``target_transform``, the labels are read from it directly.
        Otherwise each item is fetched with ``data[idx]``, which also runs the
        sample transform and is therefore much slower.
        """
        targets = getattr(data, 'targets', None)
        if (targets is not None
                and getattr(data, 'target_transform', None) is None
                and len(targets) == len(data)):
            return iter(targets)
        return (data[idx][1] for idx in range(len(data)))

    def use(self, partion):
        """Return a ``Partion`` view over partition number ``partion``."""
        return Partion(self.data, self.partitions[partion])

def partition_dataset(world_size, train = True):
    """Build the DataLoader over the calling rank's partition of CIFAR-10.

    Args:
        world_size: number of participating processes; the global BATCH_SIZE
            is divided evenly among them to get the per-rank batch size.
        train: load the training split (with the training transform) when
            True, otherwise the test split (with the test transform).

    Returns:
        A ``(loader, per_rank_batch_size)`` tuple.
    """
    # Derived instead of hard-coded: 128 // 2 == 64 with the default settings.
    per_rank_batch = max(1, BATCH_SIZE // world_size)

    dataset = datasets.CIFAR10('./data', download = True,
       train = train,
       transform = (transform_train if train else transform_test))

    # The partition is chosen by this process's rank in the process group.
    shard = DataPartitioner(dataset).use(dist.get_rank())
    loader = torch.utils.data.DataLoader(shard,
                                         batch_size = per_rank_batch,
                                         shuffle = True)
    return loader, per_rank_batch

def average_gradients(model, rank):
    """Replace every parameter gradient with its mean over all ranks.

    This is a collective operation: every rank in the process group must call
    it after its own backward pass, with a model of identical structure.

    Args:
        model: module whose parameters carry the gradients to average.
        rank: rank of the calling process. Unused by the all-reduce; kept so
            existing callers do not need to change.
    """
    size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is None:
            # No gradient was produced for this parameter; nothing to exchange.
            continue

        grad = param.grad.data
        # In-place sum across ranks (SUM is the default reduction), then divide
        # by the number of ranks to get the mean. The buffer keeps the
        # gradient's own shape and device, which the backend requires.
        dist.all_reduce(grad)
        grad /= size

In [None]:
def run(rank, size, model, criterion, optimizer):
    """Train ``model`` on this rank's data partition and test it every epoch.

    Each step runs forward/backward on the local batch, calls
    ``average_gradients`` and then steps the optimizer. After every epoch each
    rank counts correct predictions on its own test partition; the counts are
    summed across ranks and rank 0 prints and records the overall accuracy.
    When training ends, rank 0 saves both accuracy histories with ``np.save``.

    Args:
        rank: rank of this process inside the process group.
        size: total number of processes in the group.
        model: network to train; moved to CUDA here.
        criterion: loss module; moved to CUDA here.
        optimizer: optimizer already bound to ``model``'s parameters.
    """
    torch.manual_seed(8675309)
    training_set, _ = partition_dataset(size)
    testing_set, _ = partition_dataset(size, train = False)

    # NOTE(review): .cuda() without an argument targets the current CUDA
    # device, so all ranks share it unless a device is selected per rank.
    model.cuda()
    criterion = criterion.cuda()

    # Set up record holder and testing set for only the master node
    if rank == 0:
        training_accuracy = []
        testing_accuracy = []

        # Only the size of the full test split is needed, as the denominator
        # of the accuracy summed over all ranks.
        testing_size = len(datasets.CIFAR10('./data', download = True, train = False))

    for epoch_idx in range(NUM_EPOCHS):

        # Evaluation below switches to eval mode, so re-enable training
        # behaviour (e.g. dropout) at the start of every epoch.
        model.train()
        for batch_idx, (data, target) in enumerate(training_set):
            data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, target)
            loss.backward()

            if batch_idx % TRAINING_RECORD_INTERVAL == 0:
                _, predicted = torch.max(outputs.data, 1)
                correct = (predicted == target).sum().item()
                # Divide by the real batch size: the final batch of an epoch
                # can be smaller than the nominal per-rank batch.
                batch_accuracy = 100 * correct / target.size(0)
                print('Rank %d\tEpoch: %d\tIterval: %d\tAccuracy : %d%%' % (
                    rank,
                    epoch_idx,
                    batch_idx,
                    batch_accuracy))
                if rank == 0:
                    training_accuracy.append(batch_accuracy)

            average_gradients(model, rank)
            optimizer.step()

        # After each epoch record testing results on master node
        model.eval()
        testing_correct = 0
        with torch.no_grad():
            for inputs, labels in testing_set:
                inputs, labels = inputs.cuda(), labels.cuda()
                outputs = model(inputs)

                _, predicted = torch.max(outputs, 1)
                testing_correct += (predicted == labels).sum().item()

        testing_correct = torch.tensor(testing_correct).cuda()
        # wait till all processes have finished the epoch
        dist.barrier()

        # Sum the per-rank counts in place (SUM is the default reduction).
        dist.all_reduce(testing_correct)

        if rank == 0:
            # Store a plain Python number; a CUDA tensor in this list cannot
            # be converted by np.save.
            epoch_accuracy = 100 * testing_correct.item() / testing_size
            print('Epoch: %d\tAccuracy: %d %%' % (epoch_idx, epoch_accuracy))
            testing_accuracy.append(epoch_accuracy)

        dist.barrier()


    if rank == 0:
        np.save('/content/drive/My Drive/Colab Notebooks/Results/GAN Style Training/training_accuracy.npy', training_accuracy)
        np.save('/content/drive/My Drive/Colab Notebooks/Results/GAN Style Training/testing_accuracy.npy', testing_accuracy)
        #torch.save(model.state_dict(), '/content/drive/My Drive/Colab Notebooks/Results/Parallel Control/model_control.pt')

In [None]:
def init_process(rank, size, model, criterion, optimizer, fn, backend = "nccl"):
    """Join the process group as ``rank`` and then run ``fn`` inside it.

    The rendezvous endpoint is fixed to 127.0.0.1:29500 through the
    MASTER_ADDR / MASTER_PORT environment variables. ``fn`` is called as
    ``fn(rank, size, model, criterion, optimizer)`` once the group exists.
    """
    rendezvous = {"MASTER_ADDR": "127.0.0.1", "MASTER_PORT": "29500"}
    os.environ.update(rendezvous)

    dist.init_process_group(backend, rank = rank, world_size = size)

    worker_args = (rank, size, model, criterion, optimizer)
    fn(*worker_args)

In [None]:
# Build the network, loss and optimizer once in the parent process; every
# worker receives them as arguments.
processes = []
model = models.alexnet(num_classes = 10)
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005, weight_decay=0.0005)

# Start one worker process per rank.
for rank in range(WORLD_SIZE):
    p = Process(target = init_process, args = (rank, WORLD_SIZE, model, criterion, optimizer, run))
    p.start()
    processes.append(p)

for p in processes:
    p.join()

# A worker that raised exits with a non-zero code; report it instead of
# announcing success over a failed run.
failed_ranks = [idx for idx, p in enumerate(processes) if p.exitcode != 0]
if failed_ranks:
    raise RuntimeError("Worker process(es) for rank(s) %s exited abnormally" % failed_ranks)

print("Execution Finished")