In [1]:
import socket
import struct
import numpy as np
import numpy as np
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner
from torchvision.transforms import ToTensor
from flwr_datasets.visualization import plot_label_distributions
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Listening endpoint for the local coordinator: TCP on the loopback interface.
host = '127.0.0.1'
port = 5000
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Without SO_REUSEADDR, re-running this cell (or restarting the kernel) shortly
# after a previous session can fail with "Address already in use" while old
# connections on this port are still in TIME_WAIT.
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_socket.bind((host, port))
# Backlog of 8 pending connections.
server_socket.listen(8)

# Demo payload: five float32 values.
array = np.array([1.0, 2.5, 3.5, 4.5, 5.5], dtype=np.float32)
shape = array.shape

# Header part 1: every dimension as a network-order (big-endian) uint32.
shape_data = struct.pack(f'!{len(shape)}I', *shape)
shape_size = len(shape_data)

# Body: the raw buffer exactly as numpy stores it.
# NOTE(review): tobytes() keeps the array's own byte order while the header
# fields use '!' (network order) -- confirm the receiver expects that mix.
array_data = array.tobytes()
array_length = len(array_data)

# Wire layout: [uint32 shape_size][shape dims][uint32 array_length][array bytes]
packet = b''.join((
    struct.pack('!I', shape_size),
    shape_data,
    struct.pack('!I', array_length),
    array_data,
))

In [None]:
# Connection registry, filled by the accept loop below.
# Naming is from this process's point of view: a peer that announces itself as
# "...-R" is stored under "..._s", and a peer announcing "...-S" under "..._r".
# (Elsewhere in this notebook packets are sent on server_s and data is read
# from server_r -- presumably "-R"/"-S" describe the peer's own role; confirm.)
node_s = []
node_r = []

try:
    while True:
        # The first accept() runs before any timeout has been set on the
        # listening socket in this cell.
        client_socket, addr = server_socket.accept()
        # From now on accept() gives up after 1 second; the resulting
        # socket.timeout is what ends this loop once clients stop arriving.
        server_socket.settimeout(1)
        # The peer's first message is its role string.
        # NOTE(review): assumes the whole role name arrives in a single recv().
        data = client_socket.recv(1024).decode()
        if data == "Server-R":
            server_s = client_socket
        elif data == "Server-S":
            server_r = client_socket
        elif data == "Node-R":
            node_s.append(client_socket)
        elif data == "Node-S":
            node_r.append(client_socket)
        # Every accepted peer -- including one whose role matched none of the
        # branches above -- gets a length-prefixed b"start" message.
        # NOTE(review): the 'I' prefix here is native byte order, whereas the
        # packet header built in the setup cell uses '!' -- confirm the client.
        client_socket.sendall(struct.pack('I',len(b"start"))+b"start")
except socket.timeout:
    print('Timeout')
    # Put the listening socket back into blocking mode.
    server_socket.settimeout(None)

# Block until each socket in node_r, then server_r, has delivered one message
# of up to 1024 bytes; the received bytes are discarded.
for tmp_socket in node_r:
    tmp_socket.recv(1024)
# server_r is only assigned in the "Server-S" branch above, so this line needs
# such a client to have connected.
server_r.recv(1024)

In [None]:
# Frame = native-order uint32 length prefix followed by the packet bytes.
# It is identical on every iteration, so build it once.
frame = struct.pack('I', len(packet)) + packet
for _ in range(100):
    server_s.sendall(frame)

In [None]:
# Tear down the peer connections: node sockets first (node_r, then node_s),
# then the two server-side sockets.
for tmp_socket in node_r + node_s:
    tmp_socket.close()
server_s.close()
server_r.close()

In [None]:
# Define the Transformer model
class TransformerModel(nn.Module):
    """Classifier that pushes one embedded token per sample through ``nn.Transformer``.

    Each sample is flattened to a vector, projected to ``hidden_dim`` by a
    linear layer, fed to an encoder-decoder transformer as a length-1 sequence
    (the same tensor serves as source and target), and finally mapped to
    ``num_classes`` logits.

    Args:
        input_dim: Number of features per sample after flattening.
        num_classes: Size of the output logit vector.
        num_heads: Attention heads in the transformer.
        num_layers: Number of encoder layers and of decoder layers.
        hidden_dim: Embedding width (``d_model``) of the transformer.
    """

    def __init__(self, input_dim, num_classes, num_heads, num_layers, hidden_dim):
        super().__init__()
        self.embedding = nn.Linear(input_dim, hidden_dim)
        self.transformer = nn.Transformer(
            d_model=hidden_dim,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
        )
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch_size, num_classes)`` for input ``x``."""
        # Collapse everything except the batch dimension: (batch_size, input_dim).
        flat = x.view(x.size(0), -1)
        # Embed, then add a leading sequence axis: (1, batch_size, hidden_dim).
        tokens = self.embedding(flat).unsqueeze(0)
        # The single-token sequence is used as both source and target.
        decoded = self.transformer(tokens, tokens)
        # Drop the sequence axis again and classify.
        return self.fc(decoded.squeeze(0))

# Preprocessing pipeline: convert to a tensor, then normalise every channel
# with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hyper-parameters for TransformerModel.
input_dim = 32 * 32 * 3  # CIFAR-10 images are 32x32 with 3 color channels
num_classes = 10
num_heads = 4
num_layers = 2
hidden_dim = 128

# Model, loss and optimizer.
model = TransformerModel(
    input_dim=input_dim,
    num_classes=num_classes,
    num_heads=num_heads,
    num_layers=num_layers,
    hidden_dim=hidden_dim,
)
model = model.to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)



In [None]:
def run_epoch(model: nn.Module, dataloader: DataLoader, criterion: nn.Module, optimizer: optim.Optimizer, device: torch.device):
    """Train ``model`` for one pass over ``dataloader`` and print loss/accuracy.

    Args:
        model: Network to optimise; put into training mode here.
        dataloader: Iterable of dict batches with an ``"img"`` tensor and a
            ``"label"`` tensor. Each image is flattened to 32*32*3 features.
        criterion: Loss function. The reported loss assumes it returns the
            mean over the batch (the default for ``nn.CrossEntropyLoss``).
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batch tensors are moved to.

    Returns:
        None. Metrics are printed; if the loader yields no samples a notice is
        printed instead of dividing by zero.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    for batch in dataloader:
        inputs = batch["img"].view(-1, 32 * 32 * 3).to(device)
        labels = batch["label"].to(device)
        batch_size = labels.size(0)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by the batch size so a short final batch
        # does not count as much as a full one; accuracy is per sample too.
        total_loss += loss.item() * batch_size
        total_correct += (outputs.argmax(dim=1) == labels).sum().item()
        total_samples += batch_size

    if total_samples == 0:
        # An empty partition/loader would otherwise raise ZeroDivisionError.
        print("Training skipped: dataloader produced no samples")
        return

    print(f"Training loss: {total_loss / total_samples}, accuracy: {total_correct / total_samples}")

def evaluate_model(model: nn.Module, dataloader: DataLoader, criterion: nn.Module, device: torch.device):
    """Evaluate ``model`` on ``dataloader`` without gradients and print loss/accuracy.

    Args:
        model: Network to evaluate; put into eval mode here.
        dataloader: Iterable of dict batches with an ``"img"`` tensor and a
            ``"label"`` tensor. Each image is flattened to 32*32*3 features.
        criterion: Loss function. The reported loss assumes it returns the
            mean over the batch (the default for ``nn.CrossEntropyLoss``).
        device: Device the batch tensors are moved to.

    Returns:
        None. Metrics are printed; if the loader yields no samples a notice is
        printed instead of dividing by zero.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch["img"].view(-1, 32 * 32 * 3).to(device)
            labels = batch["label"].to(device)
            batch_size = labels.size(0)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Weight the batch-mean loss by the batch size so a short final
            # batch does not count as much as a full one.
            total_loss += loss.item() * batch_size
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_samples += batch_size

    if total_samples == 0:
        # An empty loader would otherwise raise ZeroDivisionError.
        print("Validation skipped: dataloader produced no samples")
        return

    print(f"Validation loss: {total_loss / total_samples}, accuracy: {total_correct / total_samples}")

In [None]:
# Partition the "train" split into 50 shards with a Dirichlet partitioner
# keyed on the "label" column (alpha=0.1, fixed seed).
# NOTE(review): min_partition_size=0 presumably permits empty shards -- confirm
# downstream code copes with a partition that has no samples.
train_partitioner = DirichletPartitioner(
    num_partitions=50,
    partition_by="label",
    alpha=0.1,
    seed=42,
    min_partition_size=0,
)

fds = FederatedDataset(
    dataset="cifar10",
    partitioners={"train": train_partitioner},
)


# Per-channel normalisation statistics used by both pipelines below.
NORM_MEAN = (0.4914, 0.4822, 0.4465)
NORM_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: padded random crop and random horizontal flip for
# augmentation, then tensor conversion and normalisation.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.Resize(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]
)

# Evaluation pipeline: no augmentation, same resize and normalisation.
transform_test = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]
)

def train_transforms(batch):
    """Replace ``batch["img"]`` with the ``transform_train`` output for each image.

    The batch dict is modified in place and also returned.
    """
    batch["img"] = list(map(transform_train, batch["img"]))
    return batch

def test_transforms(batch):
    """Replace ``batch["img"]`` with the ``transform_test`` output for each image.

    The batch dict is modified in place and also returned.
    """
    images = batch["img"]
    batch["img"] = [transform_test(image) for image in images]
    return batch

# Training data: partition 0 of the "train" split, transformed lazily by
# train_transforms and shuffled on every pass.
# NOTE(review): batch_size=1 is unusually small for training -- confirm intended.
partition = fds.load_partition(0, "train").with_transform(train_transforms)
train_loader = DataLoader(dataset=partition, batch_size=1, shuffle=True)

# Evaluation data: the whole "test" split, in fixed order, 32 samples per batch.
centralized_dataset = fds.load_split("test").with_transform(test_transforms)
test_loader = DataLoader(dataset=centralized_dataset, batch_size=32, shuffle=False)

In [4]:
from models.vit_small import ViT
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ViT configuration for 32x32 inputs split into 4x4 patches, 10 output classes.
vit_config = dict(
    image_size=32,
    patch_size=4,
    num_classes=10,
    dim=512,
    depth=6,
    heads=8,
    mlp_dim=512,
    dropout=0.1,
    emb_dropout=0.1,
)
net = ViT(**vit_config).to(device)

# Loss, optimizer, LR schedule and AMP gradient scaler for `net`.
# `criterion` and `optimizer` rebind the names created for TransformerModel
# in the earlier setup cell.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=1e-3)
# NOTE(review): no visible cell calls scheduler.step(), so the cosine schedule
# currently has no effect -- confirm.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
scaler = torch.cuda.amp.GradScaler(enabled=True)

In [7]:
def train_model(model: nn.Module, 
                train_loader: DataLoader, 
                criterion: nn.Module, 
                device: torch.device, 
                scaler: torch.cuda.amp.GradScaler, 
                optimizer: torch.optim.Optimizer):
    """Train ``model`` for one pass over ``train_loader`` with mixed precision.

    Args:
        model: Network to optimise; put into training mode here.
        train_loader: Iterable of dict batches with an ``"img"`` tensor and a
            ``"label"`` tensor.
        criterion: Loss function. The reported loss assumes it returns the
            mean over the batch (the default for ``nn.CrossEntropyLoss``).
        device: Device the batch tensors are moved to; autocast is only
            enabled when this is a CUDA device.
        scaler: Gradient scaler that scales the loss and drives the optimizer.
        optimizer: Optimizer stepping ``model``'s parameters.

    Returns:
        None. Prints the per-sample mean loss and the accuracy, or a notice
        when the loader yields no samples.
    """
    # CUDA autocast does nothing useful on CPU and only triggers a warning there.
    use_amp = torch.device(device).type == "cuda"

    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    for batch in train_loader:
        inputs = batch["img"].to(device)
        labels = batch["label"].to(device)
        batch_size = labels.size(0)

        # Clear gradients BEFORE the backward pass so anything left on the
        # parameters (e.g. from an interrupted earlier cell) cannot leak into
        # the first update.
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # loss is a batch mean; weight it by the batch size so that dividing
        # by total_samples below is a true per-sample mean for any batch size.
        total_loss += loss.item() * batch_size
        total_samples += batch_size
        total_correct += (outputs.argmax(dim=1) == labels).sum().item()

    if total_samples == 0:
        # An empty partition/loader would otherwise raise ZeroDivisionError.
        print("Train skipped: train_loader produced no samples")
        return

    print(f"Train Loss: {total_loss / total_samples:.4f}, Train Accuracy: {total_correct / total_samples:.4f}")

In [9]:
# One training pass of the ViT over the partition-0 loader.
train_model(
    model=net,
    train_loader=train_loader,
    criterion=criterion,
    device=device,
    scaler=scaler,
    optimizer=optimizer,
)

Train Loss: 1.4705, Train Accuracy: 0.3539
