In [15]:
from data import *
import torch
import networkx as nx
from node import Node
from model import ConvNet, MNISTConvNet
from utils import *

In [16]:
# Number of simulated nodes taking part in training.
# The author's note reads "1.. 10" -- presumably the supported range; confirm.
NUM_NODES: int = 3

In [17]:
# Run on the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [18]:
# Build the dataset object that every node and the validation loader draw from.
# overlap_pct / val_pct are forwarded as-is; presumably the fraction of samples
# shared between nodes and the fraction held out for validation -- confirm in
# CombinedDataset.
combined_dataset = CombinedDataset(
    data_path="./mnist_png/",
    num_nodes=NUM_NODES,
    overlap_pct=0.01,
    val_pct=0.1,
)

In [19]:
# Preprocessing applied to every image: grayscale, resize to 28x28, convert to
# a tensor.
tfms = transforms.Compose(
    [
        transforms.Grayscale(),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ]
)

# One shuffled loader per node id, each wrapping a NodeDataset built from the
# combined dataset and that node's id.
node_dataloaders = {}
for node_id in range(NUM_NODES):
    node_dataloaders[node_id] = torch.utils.data.DataLoader(
        NodeDataset(combined_dataset, node_id, transform=tfms),
        batch_size=32,
        shuffle=True,
    )

# Single loader over the ValDataset; used later to measure accuracy.
test_dataloaders = torch.utils.data.DataLoader(
    ValDataset(combined_dataset, transform=tfms),
    batch_size=32,
    shuffle=True,
)

In [20]:
# Create one Node per node id. Each receives its own dataloader, the model /
# optimizer / loss classes, the target device, and a freshly built kwargs dict
# (a new dict per node, so nodes never share the same dict object).
nodes = []
for node_idx in range(NUM_NODES):
    nodes.append(
        Node(
            node_dataloaders[node_idx],
            MNISTConvNet,
            torch.optim.Adam,
            torch.nn.CrossEntropyLoss,
            device=device,
            model_kwargs={"num_filters": 3, "kernel_size": 5, "linear_width": 64},
        )
    )

In [21]:
# Load the shared initial weights used to seed every node.
# - map_location=device remaps stored tensors onto the device chosen above, so
#   a checkpoint saved on a GPU still loads on a CPU-only machine.
# - weights_only=True restricts unpickling to tensors and plain containers
#   instead of arbitrary Python objects; this is the safer mode that the
#   torch.load FutureWarning asks callers to opt into. The file is consumed by
#   load_state_dict below, so a plain state dict is all that is needed.
sd = torch.load("./init_weights.pth", map_location=device, weights_only=True)

  sd = torch.load("./init_weights.pth")


In [22]:
# Seed node 0's model with the loaded state dict. The call is left as the
# cell's final expression so its result (the key-matching report) is displayed.
nodes[0].model.load_state_dict(sd)

<All keys matched successfully>

In [23]:
def copy_weights(from_node, to_node):
    """Copy the model weights of ``from_node`` into ``to_node``.

    Both arguments must expose a ``model`` attribute supporting
    ``state_dict()`` / ``load_state_dict()``. ``to_node.model`` is updated in
    place; nothing is returned.
    """
    source_state = from_node.model.state_dict()
    to_node.model.load_state_dict(source_state)


# Give every remaining node the same weights as node 0 so that all nodes start
# from an identical model.
for node_idx in range(1, NUM_NODES):
    target_node = nodes[node_idx]
    copy_weights(nodes[0], target_node)

In [24]:
# Prediction loss used by the training loop below.
# NOTE(review): NLLLoss expects log-probabilities as input, while the nodes are
# constructed with torch.nn.CrossEntropyLoss (which expects raw logits).
# Confirm that MNISTConvNet ends in log_softmax; if it returns logits, this
# should be CrossEntropyLoss instead.
primal_loss = torch.nn.NLLLoss()

In [25]:
# Outer iterations: number of passes over all nodes in the training loop.
OITS: int = 2000
# Inner iterations: optimizer steps each node takes per outer iteration.
IITS: int = 2
# Weight applied to the convergence_loss term that is added to the prediction loss.
RHO: float = 1.0
# NOTE(review): not referenced by the training loop in this notebook;
# presumably meant to grow RHO over time -- confirm.
RHO_SCALING: float = 1.1

In [26]:
# Global training log: one independent, initially empty list per tracked metric.
history = {metric: [] for metric in ("loss", "pred_loss", "accuracy")}

In [27]:
# Communication topology: a complete graph with one vertex per node.
# NOTE(review): `g` is not referenced by the training loop in this notebook;
# convergence_loss receives the node list directly -- confirm whether the
# graph is still needed.
g = nx.complete_graph(NUM_NODES)

In [None]:
# Custom distributed training framework implementation.
#
# Every outer iteration visits each node in turn. A node takes IITS optimizer
# steps on its own batches, minimising
#     primal_loss(pred, labels) + RHO * convergence_loss(...)
# convergence_loss comes from utils; presumably it penalises disagreement
# between this node's weights and the other nodes' -- confirm there.
#
# NOTE(review): RHO_SCALING is defined above but RHO is never rescaled here,
# and history["pred_loss"] is never appended to -- confirm whether either was
# intended.


def _evaluate_accuracy(model, dataloader, device):
    """Return the fraction of samples in ``dataloader`` that ``model`` classifies correctly.

    The model is switched to eval mode for the measurement, so layers such as
    dropout or batch norm (if the model has any) behave deterministically, and
    its previous train/eval mode is restored afterwards, even on error.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs, labels = inputs.to(device), labels.to(device)
                # Index of the largest score per row is the predicted class.
                predicted = model(inputs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)
    return correct / total


# Per-node logs, keyed by node index.
accuracies = {}
losses = {}

for node_idx, node in enumerate(nodes):
    accuracies[node_idx] = []
    losses[node_idx] = []

for oit in range(OITS):
    for node_idx, node in enumerate(nodes):

        for __ in range(IITS):
            node.optimizer.zero_grad()
            inputs, labels = node.get_next_batch()
            inputs, labels = inputs.to(device), labels.to(device)
            pred = node.model(inputs)
            loss = primal_loss(pred, labels) + (
                RHO * convergence_loss(node_idx, nodes, NUM_NODES, device)
            )
            loss.backward()
            node.optimizer.step()

        # Log the loss of the last inner step; read it from the tensor once.
        loss_value = loss.item()
        history["loss"].append(loss_value)
        losses[node_idx].append(loss_value)
        print(f"Node {node_idx} loss: {loss_value}")

        # Measure validation accuracy every 10th outer iteration.
        if oit % 10 == 0:
            accuracy = _evaluate_accuracy(node.model, test_dataloaders, device)
            history["accuracy"].append(accuracy)
            accuracies[node_idx].append(accuracy)
            print(f"Node {node_idx} accuracy: {accuracy}")
