In [1]:
import torch
import torch.nn as nn
import torch.distributed as dist 
from torch.utils.data import Dataset,DataLoader
import torch.nn.functional as F
import numpy as np 
import os 
import random 
import time 
import logging 
from sklearn.metrics import precision_score, recall_score, f1_score,accuracy_score
from datetime import timedelta
import pandas as pd

In [2]:
# Seed every RNG this notebook touches (Python, NumPy, PyTorch) with one value.
SEED = 43
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Bare-message log file for the final training loss and evaluation metrics.
# Note: logging.basicConfig is a no-op if the root logger already has handlers
# (e.g. when this cell is re-run in the same kernel).
file_name = "/SFVL/iris_trial.log"
logging.basicConfig(
    filename=file_name,
    format='%(message)s',
    level=logging.INFO,
)

In [3]:
# ------------- Init / Communication Utilities -------------

def init(rank, world_size, backend='gloo', master_addr='client1',
         master_port='29500', socket_ifname='eth0', timeout_seconds=60):
    """Join the default ``torch.distributed`` process group.

    Sets the rendezvous environment variables (overwriting any existing
    values) and then calls ``dist.init_process_group``. The previously
    hard-coded settings are now keyword parameters with the same defaults.

    Args:
        rank: rank of this process (this notebook runs as rank 2; the
            training loop receives from ranks 0 and 1).
        world_size: total number of participating processes.
        backend: ``torch.distributed`` backend name.
        master_addr: host name/IP of the rendezvous master.
        master_port: rendezvous port (int or str; stored as str).
        socket_ifname: network interface Gloo should bind to.
        timeout_seconds: timeout passed to ``init_process_group``.

    Returns:
        The value of ``dist.is_initialized()`` after initialization.
    """
    os.environ['GLOO_SOCKET_IFNAME'] = socket_ifname
    os.environ['MASTER_ADDR'] = master_addr
    # Environment values must be strings; accept an int port for convenience.
    os.environ['MASTER_PORT'] = str(master_port)

    dist.init_process_group(backend=backend,
                            rank=rank,
                            world_size=world_size,
                            timeout=timedelta(seconds=timeout_seconds))

    print(f'Rank {rank} initialized and ready.')
    return dist.is_initialized()

def recv(arr, src):
    """Blocking receive from rank ``src`` into ``arr``.

    ``arr`` must already have the sender's shape and dtype; it is filled in
    place. Returns None.
    """
    dist.recv(src=src, tensor=arr)

def snd(arr, dst):
    """Blocking send of ``arr`` to rank ``dst``.

    ``dist.send`` needs a contiguous buffer. Column slices such as
    ``gradient[:, :8]`` are not contiguous, so a contiguous version is sent.
    """
    payload = arr.contiguous()
    dist.send(tensor=payload, dst=dst)

def terminate(rank):
    """Destroy the default process group and report that ``rank`` has left."""
    dist.destroy_process_group()
    message = f'Rank {rank} successfully terminated.'
    print(message)

def send_model(model, dst):
    """Send every ``state_dict`` tensor of ``model`` to rank ``dst``, in order.

    The receiver is expected to call ``recv_model`` on a model with the same
    structure, so the sequence of tensor shapes lines up. Prints each key
    after it is sent.
    """
    state = model.state_dict()
    for name in state:
        dist.send(state[name].data, dst=dst)
        print(f"Sent {name}")

def recv_model(model, src):
    """Receive every ``state_dict`` tensor of ``model`` from rank ``src``, in order.

    Tensors returned by ``state_dict()`` share storage with the module's
    parameters/buffers, so receiving into them updates ``model`` in place.
    Prints each key after it is received.
    """
    state = model.state_dict()
    for name in state:
        dist.recv(state[name].data, src=src)
        print(f"Received {name}")

In [4]:
# ------------------ Dataset ------------------

class DatasetServer(Dataset):
    """Label dataset held by the server rank.

    Loads a NumPy array from ``path`` and serves one element per index. Each
    element is optionally passed through ``transform`` (e.g. ``ToTensor``).

    Args:
        transform: optional callable applied to each item.
        path: ``.npy`` file to load. Defaults to ``/SFVL/iris_train.npy``,
            the same directory as the notebook's other data and log paths.
    """

    def __init__(self, transform=None, path='/SFVL/iris_train.npy'):
        super().__init__()
        self.data = np.load(path)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        y = self.data[index]
        # Compare against None so a callable that happens to be falsy still runs.
        if self.transform is not None:
            y = self.transform(y)
        return y

In [5]:
class ToTensor:
    """Callable transform that converts an item to a ``torch.long`` tensor.

    Used with ``DatasetServer`` so batched labels are integer class indices,
    which is the target dtype ``nn.CrossEntropyLoss`` expects. The data is
    copied, not aliased.
    """

    def __call__(self, input):
        as_long = torch.tensor(input, dtype=torch.long)
        return as_long


In [6]:
class IrisNN(nn.Module):
    """Server-side MLP head: 16 input features -> 3 class logits.

    The training loop feeds it the two clients' 8-wide activations
    concatenated along dim 1. Hidden widths are 12 -> 8 -> 4, with a ReLU
    after every Linear except the last, so raw logits are returned.
    """

    def __init__(self):
        super().__init__()
        widths = (16, 12, 8, 4, 3)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.pop()  # no activation after the output layer
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

In [7]:
# ------------------ Training Loop ------------------

def run(num_epoch, model, batch_size=5, lr=0.01, smashed_dim=8,
        data_path='/SFVL/iris_train.npy'):
    """Train the server-side head of the split-learning setup.

    For every batch of labels, the function:

    1. receives the smashed activations from rank 0 and then rank 1;
    2. concatenates them along dim 1 and runs ``model``;
    3. takes an Adam step on the cross-entropy loss;
    4. sends each client the gradient w.r.t. its own slice of the
       concatenated input.

    Args:
        num_epoch: number of passes over the label dataset.
        model: server-side module taking input of width ``2 * smashed_dim``.
        batch_size: labels per step. Presumably must match the clients'
            batch size, since one activation batch is received per label
            batch (labels are not shuffled).
        lr: Adam learning rate.
        smashed_dim: width of each client's activation vector.
        data_path: ``.npy`` file holding the training labels.
    """
    transform = ToTensor()
    dataset = DatasetServer(transform=transform, path=data_path)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)

    model.train()
    criterion = nn.CrossEntropyLoss()
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    loss = None  # stays None if the loader yields no batches
    for epoch in range(num_epoch):
        for idx, target_batch in enumerate(dataloader):
            optim.zero_grad()
            # Size buffers from the actual batch so a trailing partial batch
            # (len(dataset) % batch_size != 0) is received with the right shape.
            current_bs = target_batch.size(0)
            smashed_data_c1 = torch.zeros((current_bs, smashed_dim), dtype=torch.float32)
            smashed_data_c2 = torch.zeros_like(smashed_data_c1)
            recv(smashed_data_c1, src=0)
            recv(smashed_data_c2, src=1)

            # A leaf that requires grad, so backward() populates .grad for the clients.
            server_input = torch.cat((smashed_data_c1, smashed_data_c2), dim=1)
            server_input.requires_grad_()
            logits = model(server_input)
            loss = criterion(logits, target_batch)

            loss.backward()
            optim.step()

            # Column views of the input gradient; snd() makes them contiguous.
            grad_c1, grad_c2 = server_input.grad.split(smashed_dim, dim=1)
            snd(grad_c1, dst=0)
            snd(grad_c2, dst=1)
            print(f"Epoch {epoch+1}, Batch {idx+1} processed loss : {loss.view(-1)}")
        if epoch == num_epoch - 1 and loss is not None:
            logging.info(f'training loss : {loss}')


In [8]:
def evaluation(model, test_path='/SFVL/iris_test.npy', smashed_dim=8):
    """Evaluate the server head on the test split.

    Loads the test labels from ``test_path``. It then receives the whole test
    set's smashed activations in one message from rank 0 and one from rank 1,
    predicts with ``model``, and prints and logs accuracy plus macro
    precision, recall and F1.

    Args:
        model: trained server-side module taking input of width ``2 * smashed_dim``.
        test_path: ``.npy`` file holding the test labels.
        smashed_dim: width of each client's activation vector.
    """
    model.eval()
    target = torch.from_numpy(np.load(test_path))
    print(target[:5])
    len_dataset = target.size(0)

    # Receive smashed data for the full test set from both clients.
    smashed_data_c1 = torch.zeros((len_dataset, smashed_dim), dtype=torch.float32)
    smashed_data_c2 = torch.zeros_like(smashed_data_c1)
    print(smashed_data_c1.size())
    recv(smashed_data_c1, src=0)
    recv(smashed_data_c2, src=1)

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        server_input = torch.cat((smashed_data_c1, smashed_data_c2), dim=1)
        logits = model(server_input)
        yhat = F.softmax(logits, dim=1)
    print(yhat[:5])

    preds = torch.argmax(yhat, dim=1)
    print(preds[:5])

    # sklearn expects NumPy arrays.
    preds_np = preds.cpu().numpy()
    true_np = target.cpu().numpy()

    # zero_division=0 yields the same value as sklearn's default ('warn')
    # for classes that are never predicted, without emitting warnings.
    accuracy = accuracy_score(true_np, preds_np)
    precision = precision_score(true_np, preds_np, average='macro', zero_division=0)
    recall = recall_score(true_np, preds_np, average='macro', zero_division=0)
    f1 = f1_score(true_np, preds_np, average='macro', zero_division=0)

    print(f"\n--- Evaluation Metrics ---")
    print(f"Accuracy : {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall   : {recall:.4f}")
    print(f"F1 Score : {f1:.4f}")
    print(f"--------------------------\n")

    logging.info(f"Accuracy : {accuracy:.4f}")
    logging.info(f"Precision: {precision:.4f}")
    logging.info(f"Recall   : {recall:.4f}")
    logging.info(f"F1 Score : {f1:.4f}")
    

In [9]:
# This process is the label-holding server: rank 2 of 3 (clients are ranks 0 and 1).
rank, world_size = 2, 3

init(rank, world_size)
model = IrisNN()

Rank 2 initialized and ready.


In [10]:
num_epoch = 100

# Both calls block on messages from client ranks 0 and 1.
run(num_epoch, model)
evaluation(model)

Epoch 1, Batch 1 processed loss : tensor([1.2452], grad_fn=<ViewBackward0>)
Epoch 1, Batch 2 processed loss : tensor([1.0399], grad_fn=<ViewBackward0>)
Epoch 1, Batch 3 processed loss : tensor([1.1919], grad_fn=<ViewBackward0>)
Epoch 1, Batch 4 processed loss : tensor([1.0681], grad_fn=<ViewBackward0>)
Epoch 1, Batch 5 processed loss : tensor([0.9802], grad_fn=<ViewBackward0>)
Epoch 1, Batch 6 processed loss : tensor([1.2610], grad_fn=<ViewBackward0>)
Epoch 1, Batch 7 processed loss : tensor([1.2560], grad_fn=<ViewBackward0>)
Epoch 1, Batch 8 processed loss : tensor([1.0526], grad_fn=<ViewBackward0>)
Epoch 1, Batch 9 processed loss : tensor([0.9356], grad_fn=<ViewBackward0>)
Epoch 1, Batch 10 processed loss : tensor([1.0845], grad_fn=<ViewBackward0>)
Epoch 1, Batch 11 processed loss : tensor([1.2025], grad_fn=<ViewBackward0>)
Epoch 1, Batch 12 processed loss : tensor([1.1509], grad_fn=<ViewBackward0>)
Epoch 1, Batch 13 processed loss : tensor([1.2380], grad_fn=<ViewBackward0>)
Epoch 1,

In [11]:
terminate(rank)  # leave the process group once training and evaluation are done

Rank 2 successfully terminated.
