In [1]:
import torch
import torch.nn as nn
import torch.distributed as dist 
from torch.utils.data import Dataset,DataLoader
import torch.nn.functional as F
import numpy as np 
import os 
import random 
import time 
import logging 
from datetime import timedelta
import pandas as pd

In [2]:
random.seed(43)
np.random.seed(43)
torch.manual_seed(43)

<torch._C.Generator at 0x7194cc0bbd70>

In [3]:
# ------------- Init / Communication Utilities -------------

def init(rank, world_size, backend='gloo'):
    os.environ['GLOO_SOCKET_IFNAME'] = 'eth0'
    os.environ['MASTER_ADDR'] = 'client1'
    os.environ['MASTER_PORT'] = '29500'
    
    dist.init_process_group(
            backend=backend,
            world_size=world_size,
            rank=rank,
            timeout=timedelta(seconds=60),
        )
    
    print(f'Rank {rank} initialized and ready.')
    return dist.is_initialized()

def recv(arr, src):
    dist.recv(tensor=arr, src=src)

def snd(arr, dst):
    dist.send(tensor=arr, dst=dst)

def terminate(rank):
    dist.destroy_process_group()
    print(f'Rank {rank} successfully terminated.')

def send_model(model, dst):
    for key, param in model.state_dict().items():
        dist.send(param.data, dst=dst)
        print(f"Sent {key}")

def recv_model(model, src):
    for key, param in model.state_dict().items():
        dist.recv(param.data, src=src)
        print(f"Received {key}")

In [4]:
# ------------------ Dataset ------------------

class DatasetC2(Dataset):
    def __init__(self, transform=None, path='/SFLVL/iris_train_c1.npy'):
        super().__init__()
        self.data = np.load(path)
        self.transform = transform
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        x=self.data[index]
        if self.transform:
            x = self.transform(x) 
        return  x

In [5]:
# ------------------ Transform ------------------

class ToTensor:
    def __call__(self, input):
        return torch.from_numpy(input).float()

In [6]:
# ------------------ Model ------------------

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_layer = nn.Linear(in_features=2, out_features=8)
        self.h1 = nn.Linear(in_features=8,out_features=6)
        self.h2 = nn.Linear(in_features=6,out_features=4)
    
    def forward(self, input):
        ac1= F.relu(self.input_layer(input))
        ac2= F.relu(self.h1(ac1))
        return F.relu(self.h2(ac2))

In [7]:
# ------------------ Training Loop ------------------

def run(num_epoch,model):
    batch_size = 5
    transform = ToTensor()
    dataset = DatasetC2(transform=transform, path='/SFVL/iris_c1.npy')
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
    
    
    model.train()
    
    optim = torch.optim.Adam(params=model.parameters(), lr=0.01)
    
    for epoch in range(num_epoch):
        for idx, (input_batch) in enumerate(dataloader):
            # Forward pass on client (bottom) model
            smashed_data = model(input_batch)
            print(smashed_data.size())

            # Send smashed data to server
            snd(smashed_data, dst=2)

            # Receive gradient from server
            gradient = torch.zeros_like(smashed_data)
            recv(gradient, src=2)

            # Backpropagation and optimizer step
            optim.zero_grad()
            smashed_data.backward(gradient)
            optim.step()
            # print(f"Epoch {epoch+1}, Batch {idx+1} processed")


In [8]:
def evaluation(model):
    model.eval()
    
    # Load test data
    input = torch.from_numpy(np.load('/SFVL/iris_test_c1.npy')).float()
    print(input[:5])
    
    smashed_data = model(input)
    
    snd(smashed_data,dst=2)
    print(smashed_data.size())
    print(f'smashed data sent successfully...')
        


In [9]:
rank = 0 
world_size = 3
model = Model()
init(rank=rank,world_size=world_size)

Rank 0 initialized and ready.


True

In [10]:
num_epoch = 100
run(num_epoch=num_epoch,model=model)
evaluation(model=model)

torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([5, 4])
torch.Size([

In [11]:
terminate(rank=rank)

Rank 0 successfully terminated.
