In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from typing import List
import time

# Simple MLP model
class SimpleMLP(nn.Module):
    """Two-layer perceptron: linear -> ReLU -> linear.

    Args:
        input_size: Number of features per sample after flattening.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, input_size=784, hidden_size=128, num_classes=10):
        super().__init__()
        # Layers are registered in the same order as before so that state-dict
        # keys and random initialisation stay unchanged.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for a batch ``x``."""
        # Collapse every dimension after the batch dimension into one.
        flat = x.view(x.shape[0], -1)
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)

# Worker class simulating a distributed worker
class Worker:
    """One simulated participant in data-parallel training.

    Bundles a model replica with the optimizer and loss function used to
    produce gradients for the batches it is handed.
    """

    def __init__(self, worker_id, model, optimizer, loss_fn, device='cpu'):
        self.worker_id = worker_id
        self.device = device
        self.model = model
        self.criterion = loss_fn
        self.optimizer = optimizer

    def compute_gradients(self, data, target):
        """Run one forward/backward pass and return the resulting gradients.

        Args:
            data: Input batch; moved to ``self.device`` before the forward pass.
            target: Matching targets; moved to ``self.device`` as well.

        Returns:
            A ``(gradients, loss)`` tuple. ``gradients`` holds one tensor per
            model parameter, in ``model.parameters()`` order, and ``loss`` is
            the batch loss as a Python number.
        """
        self.model.train()
        # Clear whatever the previous backward pass left behind.
        self.optimizer.zero_grad()

        inputs = data.to(self.device)
        labels = target.to(self.device)
        loss = self.criterion(self.model(inputs), labels)
        loss.backward()

        # Clone so the caller's copies are independent of later backward
        # passes; parameters that received no gradient contribute zeros.
        gradients = [
            torch.zeros_like(p) if p.grad is None else p.grad.clone()
            for p in self.model.parameters()
        ]
        return gradients, loss.item()

# Parameter Server for synchronization
class ParameterServer:
    """Own the global model and keep worker replicas in sync with it."""

    def __init__(self, model):
        self.global_model = model

    def aggregate_gradients(self, worker_gradients: List[List[torch.Tensor]]):
        """Average gradients from all workers.

        Args:
            worker_gradients: One list per worker, each holding one gradient
                tensor per model parameter in the same order.

        Returns:
            A list with one averaged gradient tensor per parameter.

        Raises:
            ValueError: If no worker gradients are supplied, or if the workers
                do not all report the same number of gradient tensors.
        """
        if not worker_gradients:
            raise ValueError(
                "worker_gradients must contain at least one worker's gradients"
            )

        num_workers = len(worker_gradients)
        num_params = len(worker_gradients[0])
        # zip() would silently truncate to the shortest list, so reject a
        # mismatch explicitly instead of averaging a partial set.
        if any(len(grads) != num_params for grads in worker_gradients):
            raise ValueError(
                "every worker must supply the same number of gradient tensors"
            )

        # zip(*...) regroups the data from per-worker lists into per-parameter
        # tuples, which are then summed and divided by the worker count.
        return [
            sum(per_param) / num_workers
            for per_param in zip(*worker_gradients)
        ]

    def update_global_model(self, aggregated_grads, lr=0.01):
        """Apply one plain SGD step to the global model in place.

        Args:
            aggregated_grads: One gradient tensor per parameter, in
                ``global_model.parameters()`` order.
            lr: Step size multiplied with each gradient.
        """
        # no_grad: the in-place update must not be recorded by autograd.
        with torch.no_grad():
            for param, grad in zip(self.global_model.parameters(), aggregated_grads):
                param -= lr * grad

    def broadcast_model(self, workers: "List[Worker]"):
        """Copy the global model's weights into every worker's model."""
        # Build the state dict once; it is the same for every worker.
        global_state = self.global_model.state_dict()
        for worker in workers:
            worker.model.load_state_dict(global_state)

# Data-Parallel Distributed Training Simulator
class DistributedTrainer:
    """Simulate synchronous data-parallel SGD on MNIST with a parameter server.

    Every mini-batch is split across ``num_workers`` in-process workers. Each
    worker computes gradients on its chunk using its own model replica, the
    parameter server averages them, applies one update to the global model,
    and copies the new weights back to every worker.
    """

    def __init__(self, num_workers=4, device='cpu', lr=0.01):
        """Load MNIST and build the global model, parameter server and workers.

        Args:
            num_workers: Number of simulated workers each batch is split over.
            device: Torch device that models and batches are moved to.
            lr: Learning rate used for the parameter-server update.

        Raises:
            ValueError: If ``num_workers`` is smaller than 1.
        """
        if num_workers < 1:
            raise ValueError("num_workers must be at least 1")

        self.num_workers = num_workers
        self.device = device
        self.lr = lr

        # Load MNIST dataset
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        self.train_dataset = datasets.MNIST('./data', train=True, download=True,
                                           transform=transform)
        self.test_dataset = datasets.MNIST('./data', train=False, transform=transform)

        # Create global model and parameter server
        self.global_model = SimpleMLP().to(device)
        self.param_server = ParameterServer(self.global_model)

        # Create workers, each starting from a copy of the global weights.
        self.workers = []
        for i in range(num_workers):
            worker_model = SimpleMLP().to(device)
            worker_model.load_state_dict(self.global_model.state_dict())
            # Worker requires an optimizer and a loss function. The model
            # emits raw class scores, so cross-entropy is used as the loss.
            worker_optimizer = optim.SGD(worker_model.parameters(), lr=lr)
            self.workers.append(
                Worker(i, worker_model, worker_optimizer,
                       nn.CrossEntropyLoss(), device)
            )

        self.train_losses = []
        self.test_accuracies = []

    def split_batch(self, data, target):
        """Split a batch into ``num_workers`` contiguous chunks.

        All chunks have ``batch_size // num_workers`` samples except the last
        one, which also receives the remainder.

        Returns:
            ``(data_chunks, target_chunks)``, two lists of length
            ``num_workers``.
        """
        batch_size = data.size(0)
        chunk_size = batch_size // self.num_workers

        data_chunks = []
        target_chunks = []

        for i in range(self.num_workers):
            start_idx = i * chunk_size
            # The last worker takes everything that is left.
            end_idx = start_idx + chunk_size if i < self.num_workers - 1 else batch_size
            data_chunks.append(data[start_idx:end_idx])
            target_chunks.append(target[start_idx:end_idx])

        return data_chunks, target_chunks

    def train_epoch(self, dataloader):
        """Train for one epoch and return the mean per-batch loss.

        Args:
            dataloader: Iterable yielding ``(data, target)`` batches.
        """
        epoch_loss = 0
        num_batches = 0

        for batch_idx, (data, target) in enumerate(dataloader):
            # Step 1: Split batch among workers
            data_chunks, target_chunks = self.split_batch(data, target)

            # Step 2: Each worker computes gradients on their chunk
            worker_gradients = []
            batch_losses = []

            for worker, data_chunk, target_chunk in zip(
                    self.workers, data_chunks, target_chunks):
                # A batch smaller than the worker count leaves some chunks
                # empty; skip them so they contribute neither a loss value
                # nor gradients to the average.
                if data_chunk.size(0) == 0:
                    continue
                grads, loss = worker.compute_gradients(data_chunk, target_chunk)
                worker_gradients.append(grads)
                batch_losses.append(loss)

            if not worker_gradients:
                # Nothing was computed for this batch (it had no samples).
                continue

            # Step 3: Parameter server aggregates gradients
            aggregated_grads = self.param_server.aggregate_gradients(worker_gradients)

            # Step 4: Update global model
            self.param_server.update_global_model(aggregated_grads, lr=self.lr)

            # Step 5: Broadcast updated model to workers
            self.param_server.broadcast_model(self.workers)

            batch_loss = np.mean(batch_losses)
            epoch_loss += batch_loss
            num_batches += 1

            if batch_idx % 100 == 0:
                print(f"  Batch {batch_idx}/{len(dataloader)}, "
                      f"Avg Loss: {batch_loss:.4f}")

        return epoch_loss / num_batches

    def evaluate(self):
        """Return the global model's accuracy on the test set, in percent."""
        self.global_model.eval()
        correct = 0
        total = 0

        test_loader = DataLoader(self.test_dataset, batch_size=1000, shuffle=False)

        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.global_model(data)
                # Predicted class = index of the largest score per sample.
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)

        accuracy = 100. * correct / total
        return accuracy

    def train(self, num_epochs=5, batch_size=256):
        """Run the full training loop and record per-epoch metrics.

        After every epoch the mean training loss and the test accuracy are
        appended to ``self.train_losses`` and ``self.test_accuracies``.

        Args:
            num_epochs: Number of passes over the training set.
            batch_size: Global batch size before it is split among workers.
        """
        train_loader = DataLoader(self.train_dataset, batch_size=batch_size,
                                 shuffle=True)

        print(f"Starting Data-Parallel Distributed Training")
        print(f"Number of workers: {self.num_workers}")
        print(f"Batch size: {batch_size} (split into {batch_size//self.num_workers} per worker)")
        print("="*60)

        for epoch in range(num_epochs):
            start_time = time.time()

            avg_loss = self.train_epoch(train_loader)
            test_acc = self.evaluate()

            epoch_time = time.time() - start_time

            self.train_losses.append(avg_loss)
            self.test_accuracies.append(test_acc)

            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print(f"  Average Loss: {avg_loss:.4f}")
            print(f"  Test Accuracy: {test_acc:.2f}%")
            print(f"  Time: {epoch_time:.2f}s")
            print("="*60)

    def get_train_losses(self):
        """Return the list of per-epoch training losses recorded by ``train``."""
        return self.train_losses

    def get_test_accuracies(self):
        """Return the list of per-epoch test accuracies recorded by ``train``."""
        return self.test_accuracies

    def plot_results(self):
        """Plot training curves.

        Draws training loss and test accuracy per epoch side by side, saves
        the figure as ``distributed_training_results.png`` and shows it.
        """
        _, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

        loss_ax.plot(self.train_losses)
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Training Loss')
        loss_ax.grid(True)

        acc_ax.plot(self.test_accuracies)
        acc_ax.set_xlabel('Epoch')
        acc_ax.set_ylabel('Test Accuracy (%)')
        acc_ax.grid(True)

        plt.tight_layout()
        plt.savefig('distributed_training_results.png', dpi=150, bbox_inches='tight')
        plt.show()
        

In [None]:
def plot_result(train_lossess, test_accuracies):
    """Plot training loss and test accuracy per epoch side by side.

    The figure is saved as ``distributed_training_results.png`` in the current
    working directory and then displayed.

    Args:
        train_lossess: Sequence of training-loss values, plotted against
            their index (labelled "Epoch").
        test_accuracies: Sequence of test accuracies in percent, plotted the
            same way.
    """
    # NOTE: the first parameter keeps its existing spelling so the signature
    # stays unchanged for callers. The body now reads the arguments it is
    # given instead of the undefined names ``train_losses`` and ``self``.
    _, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    loss_ax.plot(train_lossess)
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Training Loss')
    loss_ax.grid(True)

    acc_ax.plot(test_accuracies)
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Test Accuracy (%)')
    acc_ax.grid(True)

    plt.tight_layout()
    plt.savefig('distributed_training_results.png', dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
from matplotlib import pyplot as plt
# Seed both the PyTorch and NumPy random number generators with one value.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Build the simulator: four workers, everything on the CPU.
trainer = DistributedTrainer(num_workers=4, device='cpu')

# Five epochs with a global batch of 256 samples.
trainer.train(num_epochs=5, batch_size=256)

# Ask the trainer to plot what it recorded.
trainer.plot_results()