In [7]:
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import socket
import pickle
import struct
from time import sleep
import struct  # For packing/unpacking data size
from typing import Any, Dict, List, Tuple, Set
import time
from models import myResNet, SimpleModel
from galore_torch import GaLoreAdamW, GaLoreAdamW8bit, GaLoreAdafactor
import multiprocessing
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('device:', device)

device: cuda


In [8]:
def galore_parameters(model, rank=128, update_proj_gap=200, scale=0.25, proj_type='std'):
    """Split ``model``'s parameters into a plain group and a GaLore group.

    A parameter is placed in the GaLore group when its name (from
    ``model.named_parameters()``; case-sensitive substring match) satisfies
    any of:
      * contains ``'embeddings'`` and not ``'LayerNorm'``;
      * contains ``'layer'`` and ``'weight'`` and not ``'LayerNorm'``;
      * contains ``'classifier'`` and not ``'bias'``.
    Every other parameter goes into the plain group.

    Args:
        model: module whose parameters are grouped.
        rank: GaLore projection rank attached to the GaLore group.
        update_proj_gap: GaLore projection update interval.
        scale: GaLore scale factor.
        proj_type: GaLore projection type ('std', 'reverse_std', 'right',
            'left' or 'full').

    Returns:
        ``[{'params': plain}, {'params': galore, 'rank': ..., 'update_proj_gap': ...,
        'scale': ..., 'proj_type': ...}]``, usable as the param groups of a
        GaLore optimizer such as ``GaLoreAdamW``.

    Raises:
        ValueError: if a parameter selected for GaLore is not 2-D.
    """
    galore_params = []
    non_galore_params = []
    for name, param in model.named_parameters():
        is_galore = (
            ('embeddings' in name and 'LayerNorm' not in name)
            or ('layer' in name and 'weight' in name and 'LayerNorm' not in name)
            or ('classifier' in name and 'bias' not in name)
        )
        if not is_galore:
            non_galore_params.append(param)
            continue

        # GaLore projects 2-D gradient matrices only; fail fast and say which
        # parameter is the problem (e.g. a 4-D conv weight named 'layer1...').
        if param.dim() != 2:
            raise ValueError(
                f'Galore only supports 2D parameters '
                f'(got {name!r} with shape {tuple(param.shape)})'
            )
        galore_params.append(param)

    return [
        {'params': non_galore_params},
        {
            'params': galore_params,
            'rank': rank,
            'update_proj_gap': update_proj_gap,
            'scale': scale,
            'proj_type': proj_type,
        },
    ]


In [9]:
class Worker:
    """Training worker that exchanges gradients with a parameter server.

    On each batch the worker computes local gradients, sends them to the
    server over a fresh TCP connection as a length-prefixed pickle frame,
    waits for the averaged gradients in the same framing, and applies them
    with its local optimizer.
    """

    # Frame header: payload length as a 4-byte unsigned int, network byte order.
    _HEADER = struct.Struct("!I")

    def __init__(self, worker_id, host="localhost", port=60000):
        self.worker_id = worker_id
        self.server_host = host
        self.server_port = port
        # Durations (seconds) of every timed send/receive.
        self.network_latency_list = []
        # Clock markers for the transfer currently being timed.
        self.start_time = 0
        self.end_time = 0
        self.load_data()

    def calc_network_latency(self, is_send):
        """Record and print the duration of the transfer just timed.

        Args:
            is_send: True for a send, False for a receive (affects the label only).
        """
        latency = self.end_time - self.start_time
        self.network_latency_list.append(latency)
        direction = 'Send' if is_send else 'Recv'
        print(f'{direction} Network latency: {latency}')
        # Reset so a stale reading is never reused.
        self.start_time = 0
        self.end_time = 0

    def print_total_network_latency(self):
        """Print the sum of all recorded send/receive durations."""
        print(f'Total network latency for worker {self.worker_id}: {sum(self.network_latency_list)}')

    def load_data(self):
        """Load this worker's pickled dataloader from ``dataloader_{worker_id}.pkl``."""
        # NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
        with open(f"dataloader_{self.worker_id}.pkl", "rb") as f:
            self.dataloader = pickle.load(f)

    def send_data(self, sock, data):
        """Pickle ``data`` and send it as a single length-prefixed frame."""
        payload = pickle.dumps(data)
        # Header and payload go out in one sendall: two back-to-back small
        # writes followed by a read can stall on Nagle + delayed ACK.
        frame = self._HEADER.pack(len(payload)) + payload

        self.start_time = time.perf_counter()
        sock.sendall(frame)
        self.end_time = time.perf_counter()
        self.calc_network_latency(True)

    @staticmethod
    def _recv_exact(sock, size):
        """Read exactly ``size`` bytes from ``sock``.

        Returns:
            A ``bytearray`` of length ``size``, or None if the peer closed the
            connection before that many bytes arrived.
        """
        buf = bytearray(size)
        view = memoryview(buf)
        received = 0
        while received < size:
            # recv() may return fewer bytes than requested, so keep reading.
            n = sock.recv_into(view[received:], size - received)
            if n == 0:
                return None
            received += n
        return buf

    def recv_data(self, sock):
        """Receive one length-prefixed pickle frame.

        Returns:
            The unpickled object, or None if the connection closed before a
            complete header or payload was received.
        """
        header = self._recv_exact(sock, self._HEADER.size)
        if header is None:
            return None
        (size,) = self._HEADER.unpack(header)

        # Time only the payload transfer, as before.
        self.start_time = time.perf_counter()
        payload = self._recv_exact(sock, size)
        if payload is None:
            return None
        self.end_time = time.perf_counter()
        self.calc_network_latency(False)

        # NOTE(security): unpickling network data executes arbitrary code;
        # only connect to a trusted server.
        return pickle.loads(payload)

    def send_recv(self, gradients) -> Tuple[bool, Any]:
        """Send ``gradients`` to the server and wait for the averaged reply.

        Opens a new TCP connection to ``(server_host, server_port)`` per call.

        Returns:
            ``(True, avg_gradients)`` on success, or ``(False, None)`` if the
            connection closed before a complete reply arrived.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.connect((self.server_host, self.server_port))
            print(f"Worker {self.worker_id} connected to server.")

            self.send_data(s, gradients)
            print(f"Worker {self.worker_id} sent gradients {gradients}.")

            avg_gradients = self.recv_data(s)
            if avg_gradients is None:
                return (False, None)

        return (True, avg_gradients)

    def train_worker(self, epochs=5):
        """Train a ``SimpleModel`` for ``epochs`` passes over ``self.dataloader``.

        Each batch's gradients are replaced by the server's averaged gradients
        before the SGD step. Batches whose exchange fails are skipped.
        """
        model = SimpleModel()
        # model = myResNet()
        # Unused by SGD; kept for the GaLore alternative below. Note that this
        # call raises ValueError if a GaLore-selected parameter is not 2-D.
        param_groups = galore_parameters(model)
        model = model.to(device)
        # optimizer = GaLoreAdamW(param_groups, lr=0.01)
        optimizer = optim.SGD(model.parameters(), lr=0.01)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            for batch_X, batch_y in self.dataloader:
                batch_X = batch_X.to(device)
                batch_y = batch_y.to(device)

                outputs = model(batch_X)
                loss = criterion(outputs, batch_y)

                optimizer.zero_grad()
                loss.backward()

                # Send CPU copies so the server can unpickle them without this
                # worker's CUDA device; skip parameters with no gradient.
                gradients = {
                    name: param.grad.detach().cpu()
                    for name, param in model.named_parameters()
                    if param.grad is not None
                }

                for name, grad in gradients.items():
                    print(f"Gradient size for {name}: {grad.size()}")

                update, avg_gradients = self.send_recv(gradients)

                if not update:
                    print(f"Worker {self.worker_id} failed to receive averaged gradients.")
                    continue

                print(f"Worker {self.worker_id} received averaged gradients {avg_gradients}.")

                # Install the averaged gradients, matching each parameter's
                # device/dtype (the server's tensors may live elsewhere).
                for name, param in model.named_parameters():
                    if name in gradients:
                        param.grad = avg_gradients[name].to(device=param.device, dtype=param.dtype)
                optimizer.step()

            print(f"Worker {self.worker_id} completed epoch {epoch}")

        print(f"Worker {self.worker_id} finished training.")


In [10]:
def run_worker(worker_id):
    """Create the Worker for ``worker_id``, train it, then report its total network latency."""
    node = Worker(worker_id)
    node.train_worker()
    node.print_total_network_latency()

In [11]:
import os
import subprocess

# Worker IDs to launch; each worker's stdout/stderr goes to its own log file.
worker_ids = [0, 1, 2]
log_files = [f'logs/worker_log_{wid}.txt' for wid in worker_ids]

# open() below raises FileNotFoundError if the log directory does not exist.
os.makedirs('logs', exist_ok=True)

# Start one process per worker. sys.executable (rather than a bare 'python')
# runs worker.py with the same interpreter/environment as this kernel.
processes = []
for wid, log_file in zip(worker_ids, log_files):
    with open(log_file, 'w') as f:
        p = subprocess.Popen(
            [sys.executable, 'worker.py', str(wid)],
            stdout=f,
            stderr=subprocess.STDOUT,  # interleave stderr into the same log
        )
        processes.append(p)

# Wait for all processes to complete.
for p in processes:
    p.wait()

print("All workers have completed execution.")

# Surface failed workers instead of hiding them behind the completion message.
for wid, p, log_file in zip(worker_ids, processes, log_files):
    if p.returncode != 0:
        print(f"Worker {wid} exited with code {p.returncode}; see {log_file}")


All workers have completed execution.
