# Distributed Training with HuggingFace Trainer

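This notebook implements data-parallel distributed training with Hugging Face's `Trainer`, training a VGG-13 model on the ImageNet dataset. After each backward pass, every worker sends its gradients to a central server and applies the averaged gradients it receives back.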

In [2]:
import sys
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from typing import *
import socket
import pickle
import struct
import time
from transformers import Trainer, TrainingArguments
from datasets import load_dataset
from torchvision import transforms
import torchvision.models as models
import random
import numpy as np

In [3]:
class ImageNetDataset(Dataset):
    """Adapt an indexable image/label dataset to the dict format ``Trainer`` consumes.

    Each record of the wrapped dataset is read for its ``'image'`` and
    ``'label'`` fields. The image is passed through ``transform`` when one was
    supplied, and the pair is returned as
    ``{'pixel_values': <image>, 'labels': <label>}``.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing whose
            items are mappings with ``'image'`` and ``'label'`` keys.
        transform: Optional callable applied to each image.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        pixels, target = record['image'], record['label']
        transform = self.transform
        return {
            'pixel_values': transform(pixels) if transform else pixels,
            'labels': target,
        }

In [4]:
class DistributedTrainer(Trainer):
    """``Trainer`` subclass that averages gradients through a central server.

    After every backward pass, ``training_step`` pickles this worker's
    parameter gradients and sends them to ``server_host:server_port`` over a
    new TCP connection. It then overwrites the local ``.grad`` tensors with
    the averaged gradients the server sends back. The wall-clock time of each
    send and each receive is appended to ``network_latency_list``.

    Extra keyword arguments (consumed here, NOT forwarded to ``Trainer``):
        server_host: Hostname of the averaging server. Default ``'localhost'``.
        server_port: TCP port of the averaging server. Default ``60000``.
        worker_id: Identifier used in log messages. Default ``0``.

    Wire format, in both directions: a 4-byte big-endian unsigned length
    (``struct`` format ``"!I"``) followed by that many bytes of pickled payload.
    """

    def __init__(self, *args, **kwargs):
        # Trainer.__init__ does not accept these options, so they must be
        # removed from kwargs before delegating. Forwarding them raises TypeError.
        server_host = kwargs.pop('server_host', 'localhost')
        server_port = kwargs.pop('server_port', 60000)
        worker_id = kwargs.pop('worker_id', 0)
        super().__init__(*args, **kwargs)
        self.server_host = server_host
        self.server_port = server_port
        self.worker_id = worker_id
        self.network_latency_list = []  # seconds per send/recv, in call order
        self.start_time = 0
        self.end_time = 0
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def send_data(self, sock, data):
        """Pickle ``data`` and send it on ``sock`` with a 4-byte length header.

        Only the socket writes are timed (serialization is excluded), and the
        duration is recorded via ``calc_network_latency``.
        """
        data_bytes = pickle.dumps(data)

        self.start_time = time.perf_counter()
        sock.sendall(struct.pack("!I", len(data_bytes)))
        sock.sendall(data_bytes)
        self.end_time = time.perf_counter()
        self.calc_network_latency(True)

    @staticmethod
    def _recv_exact(sock, size):
        """Read exactly ``size`` bytes from ``sock``.

        ``socket.recv`` may return fewer bytes than requested, so this loops
        until ``size`` bytes have arrived.

        Returns:
            The bytes read, or ``None`` if the peer closed the connection first.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            packet = sock.recv(remaining)
            if not packet:
                return None
            chunks.append(packet)
            remaining -= len(packet)
        return b"".join(chunks)

    def recv_data(self, sock):
        """Receive one length-prefixed pickled message from ``sock``.

        Timing starts after the 4-byte header has arrived, so time spent
        waiting for the peer to begin replying is not counted.

        Returns:
            The unpickled object, or ``None`` if the connection closed before
            a complete header or payload was received.
        """
        size_data = self._recv_exact(sock, 4)
        if size_data is None:
            return None
        (size,) = struct.unpack("!I", size_data)

        self.start_time = time.perf_counter()
        data = self._recv_exact(sock, size)
        if data is None:
            return None
        self.end_time = time.perf_counter()
        self.calc_network_latency(False)

        # SECURITY: pickle.loads can execute arbitrary code embedded in the
        # payload. Only ever connect to a trusted server on a trusted network.
        return pickle.loads(data)

    def send_recv(self, gradients) -> Tuple[bool, Any]:
        """Send ``gradients`` to the server and wait for the averaged result.

        A new TCP connection is opened (and closed) for every call. Connection
        errors such as ``ConnectionRefusedError`` propagate to the caller.

        Returns:
            ``(True, averaged_gradients)`` on success, or ``(False, None)`` if
            the server closed the connection before sending a complete reply.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.connect((self.server_host, self.server_port))
            print(f"Worker {self.worker_id} connected to server.")

            self.send_data(s, gradients)
            # Log only the tensor count. Printing every gradient tensor on
            # every step floods the log and costs significant time.
            print(f"Worker {self.worker_id} sent {len(gradients)} gradient tensors.")

            avg_gradients = self.recv_data(s)
            if avg_gradients is None:
                return (False, None)

        return (True, avg_gradients)

    def training_step(self, model, inputs, num_items_in_batch=None):
        """Compute the loss, backpropagate, and swap in server-averaged gradients.

        Args:
            model: The model being trained.
            inputs: One batch as produced by the data collator.
            num_items_in_batch: Newer ``transformers`` versions pass this third
                argument from ``Trainer.train``. It is accepted for
                compatibility and otherwise ignored.

        Returns:
            The detached loss, divided by ``gradient_accumulation_steps`` when
            that is > 1. If no averaged reply arrives, the locally computed
            gradients are left in place.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)
        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        loss.backward()

        # Parameters that are frozen or unused in the forward pass have
        # grad None. Skip them instead of crashing on None.cpu().
        gradients = {
            name: param.grad.detach().cpu()
            for name, param in model.named_parameters()
            if param.grad is not None
        }

        update, avg_gradients = self.send_recv(gradients)
        if not update:
            print(f"Worker {self.worker_id} failed to receive averaged gradients.")
            return loss.detach()

        # Put each averaged tensor on the same device as its parameter. The
        # model's device is not guaranteed to equal self.device.
        params = dict(model.named_parameters())
        for name in gradients:
            params[name].grad = avg_gradients[name].to(params[name].device)

        return loss.detach()

    def calc_network_latency(self, is_send):
        """Record and print the duration of the last timed send/receive, then reset the timer."""
        self.network_latency_list.append(self.end_time - self.start_time)
        if is_send:
            print(f'Send Network latency: {self.end_time - self.start_time}')
        else:
            print(f'Recv Network latency: {self.end_time - self.start_time}')
        self.start_time = 0
        self.end_time = 0

    def print_total_network_latency(self):
        """Print the summed send+receive time (seconds) recorded by this worker."""
        print(f'Total network latency for worker {self.worker_id}: {sum(self.network_latency_list)}')

In [5]:
class Worker:
    """One training worker: VGG-13, its ImageNet-1k shard, and a DistributedTrainer.

    Args:
        worker_id: Zero-based index of this worker. It selects the worker's
            contiguous shard of the training set and is forwarded to the
            trainer for log messages.
        host: Hostname of the gradient-averaging server.
        port: TCP port of the gradient-averaging server.
        num_workers: Number of workers the training set is split across.
            Defaults to 3, the value that was previously hard-coded.

    Raises:
        ValueError: If ``num_workers < 1`` or ``worker_id`` is outside
            ``[0, num_workers)``. This is raised while building the data shard.
    """

    def __init__(self, worker_id, host="localhost", port=60000, num_workers=3):
        self.worker_id = worker_id
        self.server_host = host
        self.server_port = port
        self.num_workers = num_workers
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.setup_model_and_data()

    @staticmethod
    def _shard_range(worker_id, num_workers, total_size):
        """Return the contiguous index range of the training set owned by ``worker_id``.

        Every worker receives ``total_size // num_workers`` samples. The last
        worker also takes the remainder, so no sample is dropped.
        """
        if num_workers < 1:
            raise ValueError(f"num_workers must be >= 1, got {num_workers}")
        if not 0 <= worker_id < num_workers:
            raise ValueError(
                f"worker_id must be in [0, {num_workers}), got {worker_id}"
            )
        shard_size = total_size // num_workers
        start = worker_id * shard_size
        end = total_size if worker_id == num_workers - 1 else start + shard_size
        return range(start, end)

    def setup_model_and_data(self):
        """Build the model, this worker's training shard, and its TrainingArguments."""
        # VGG-13 from torchvision with default constructor arguments.
        self.model = models.vgg13()
        self.model = self.model.to(self.device)

        transform = transforms.Compose([
            # Force 3-channel RGB input. Normalize below carries 3-channel
            # statistics and fails on grayscale (1-channel) images, which
            # ImageNet-1k is known to contain. NOTE: a lambda is not picklable;
            # replace it with a module-level function if DataLoader worker
            # processes are ever started with the "spawn" method.
            transforms.Lambda(lambda img: img.convert("RGB")),
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        # Load ImageNet-1k. Only its 'train' split is used, sharded
        # contiguously across the workers.
        dataset = load_dataset("imagenet-1k")
        train_dataset = ImageNetDataset(dataset['train'], transform=transform)
        shard = self._shard_range(self.worker_id, self.num_workers, len(train_dataset))
        self.train_dataset = torch.utils.data.Subset(train_dataset, shard)

        # No eval_dataset is given to the trainer, so evaluation is left
        # disabled, along with load_best_model_at_end, which requires it.
        # Requesting per-epoch evaluation without an eval set makes Trainer raise.
        self.training_args = TrainingArguments(
            output_dir=f"./results_worker_{self.worker_id}",
            num_train_epochs=5,
            per_device_train_batch_size=32,
            per_device_eval_batch_size=32,
            weight_decay=0.01,
            logging_dir=f"./logs_worker_{self.worker_id}",
            logging_steps=10,
            save_strategy="epoch",
            push_to_hub=False,
        )

    def train_worker(self):
        """Train this worker's shard with a DistributedTrainer, then print its total network latency."""
        trainer = DistributedTrainer(
            model=self.model,
            args=self.training_args,
            train_dataset=self.train_dataset,
            server_host=self.server_host,
            server_port=self.server_port,
            worker_id=self.worker_id
        )

        trainer.train()
        trainer.print_total_network_latency()

In [6]:
def run_worker(worker_id, seed=42):
    """Seed every RNG, then build and train the worker with id ``worker_id``.

    Args:
        worker_id: Zero-based worker index passed to ``Worker``.
        seed: Seed for ``torch``, ``random`` and ``numpy`` (default 42).
            All workers use the same seed, so the model built inside
            ``Worker`` is initialised identically on every worker.

    The total network latency is printed by ``Worker.train_worker``. ``Worker``
    has no ``print_total_network_latency`` method of its own, so it is not
    called here.
    """
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    worker = Worker(worker_id)
    worker.train_worker()

In [None]:
import os
import subprocess
import sys

# Launch one `worker.py` process per worker id, each writing stdout and stderr
# to its own log file. Then wait for all of them and report any failures.
worker_ids = [0, 1, 2]
log_files = [f'logs/worker_log_{wid}.txt' for wid in worker_ids]

# open(..., 'w') below fails with FileNotFoundError if the directory is missing.
os.makedirs('logs', exist_ok=True)

processes = []
for wid, log_file in zip(worker_ids, log_files):
    # The child gets its own copy of the file descriptor, so closing ours when
    # the `with` block exits does not cut off the worker's output.
    with open(log_file, 'w') as f:
        # sys.executable runs the workers under the same interpreter/venv
        # as this notebook rather than whatever `python` is first on PATH.
        p = subprocess.Popen([sys.executable, 'worker.py', str(wid)], stdout=f, stderr=f)
        processes.append(p)

# Wait for every process and surface non-zero exit codes instead of
# silently reporting success.
for wid, log_file, p in zip(worker_ids, log_files, processes):
    returncode = p.wait()
    if returncode != 0:
        print(f"Worker {wid} exited with code {returncode}; see {log_file}.")

print("All workers have completed execution.")