In [1]:
import os
import time
import copy
import random
import math
import os
import time
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Subset

from torchvision import datasets
from torchvision import transforms

import matplotlib.pyplot as plt
from PIL import Image

import IPython

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


## Start Connecting

In [3]:
train_dataset_CIFAR10 = datasets.CIFAR10(root='data',
                                       train=True,
                                       transform=transforms.ToTensor(),
                                       download=True)
test_dataset_CIFAR10 = datasets.CIFAR10(root='data',
                                      train=False,
                                      transform=transforms.ToTensor(),
                                      download=True)
train_loader_CIFAR10 = DataLoader(dataset=train_dataset_CIFAR10,
                                  batch_size=64,
                                  shuffle=True)
test_loader_CIFAR10 = DataLoader(dataset=test_dataset_CIFAR10,
                                 batch_size=64,
                                 shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
data_tensor, label_tensor = torch.load('/workspace/Taein/jooyeong/data/cifar10_split_1.pt')

# 텐서 데이터를 데이터셋으로 변환
client_dataset = TensorDataset(torch.stack(data_tensor), torch.tensor(label_tensor))

# DataLoader를 사용하여 데이터셋 로드
client_loader = DataLoader(client_dataset, batch_size=32, shuffle=True)

FileNotFoundError: [Errno 2] No such file or directory: '/workspace/Taein/jooyeong/data/cifar10_split_1.pt'

# ResNet18 Model

In [4]:
def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

class BasicBlock_18(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock_18, self).__init__()
        self.conv_1 = conv3x3(inplanes, planes, stride)
        self.bn_1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv_2 = conv3x3(planes, planes)
        self.bn_2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv_1(x)
        out = self.bn_1(out)
        out = self.relu(out)

        out = self.conv_2(out)
        out = self.bn_2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = out + residual
        out = self.relu(out)

        return out

class ResNet_18(nn.Module):
    def __init__(self, block, layers, num_classes, grayscale):
        self.inplanes = 64
        if grayscale:
            in_dim = 1
        else:
            in_dim = 3

        super(ResNet_18, self).__init__()
        self.conv1g = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bng = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512*block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, (2. / n)**.5)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes*block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1g(x)
        x = self.bng(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x_full = x.view(x.size(0), -1)
        logits = self.fc(x_full)
        #probas = F.softmax(logits, dim=1)
        return logits, x_full

def resnet18(num_classes, grayscale):
    model = ResNet_18(block=BasicBlock_18, layers=[2,2,2,2], num_classes=num_classes, grayscale=grayscale)
    return model

# Client Model Training

In [5]:
criterion = nn.CrossEntropyLoss()

def avg_train_client(id, client_loader, global_model, num_local_epochs, lr):
    local_model = copy.deepcopy(global_model)
    local_model = local_model.to(device)
    local_model.train()
    optimizer = torch.optim.SGD(local_model.parameters(), lr=lr, momentum = 0.9)

    for epoch in range(num_local_epochs):
        print('    Epoch {}'.format(epoch+1))
        for (i, (x,y)) in enumerate(client_loader):
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            local_out, _ = local_model(x)
      
            loss = criterion(local_out, y)
            loss.backward()
            optimizer.step()
            

    
    
    return local_model

In [6]:
def log_message(message):
    current_time = time.strftime("%M:%S", time.gmtime())  # [mm:ss] 포맷으로 시간 가져오기
    print(f"[{current_time}] {message}")

In [10]:
import socket
import time
import multiprocessing

server_address = ('127.0.0.1', 9090)
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect(server_address)
id = 0

criterion = nn.CrossEntropyLoss()


def getId():
    data = s.recv(1024).decode('utf-8')
    id = int(data)
    log_message(f"client id : {id}")
    return id


def recvFile(fileName):

    # 서버에서 보낼 때 다음으로 보낼 정보가 파일인지 메시지인지 알려준다.
    msg = s.recv(1024).decode('utf-8')
    if (msg == "end"):
        return msg

    # FileHandler.writeFile과 매핑
    with open(fileName, 'wb') as f:
        file_size = int.from_bytes(s.recv(8), byteorder='big')
        log_message(f"file size : {file_size}")
        received_size = 0
        while received_size < file_size:
            data = s.recv(4096)
            f.write(data)
            received_size += len(data)
    log_message(f"{fileName} received and saved")
    return msg


def sendFile(fileName):
    with open(fileName, 'rb') as f:
        file_size = os.path.getsize(fileName)
        log_message(f"Sending {fileName} to server, size : {file_size}bytes")
        s.sendall(file_size.to_bytes(8, byteorder='big'))    # 파일 크기 전송
        data = f.read(4096)
        while data:
            s.sendall(data)
            data = f.read(4096)
    log_message(f"{fileName} sent to server")


def Learning(id, round):

    log_message("Start Learning")
    model = resnet18(10, False)
    model = torch.load('./global_model.pt')  # 이 부분 나중에 수정
    model.to(device)

    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for epoch in range(1):
        for (i, (x, y)) in enumerate(train_loader_CIFAR10):
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            local_out, _ = model(x)
            loss = criterion(local_out, y)
            loss.backward()
            optimizer.step()
        log_message('Round {}  Client {} Training Success'.format(round, id))

    torch.save(model, f'client_model_{id}.pt')
    log_message("End Learning")


def getUpdatedPT(id, round):  # 여기서 end 메시지도 받음
    msg = recvFile(f"global_model.pt")

    if (msg == "end"):
        return "end"

    Learning(id, round)
    sendFile(f"client_model_{id}.pt")
    msg = "done learning\n"
    s.sendall(msg.encode('utf-8'))


def run():
    log_message("시작")
    log_message(f"서버에 접속 중: {s.getpeername()}")
    id = getId()
    round = 1
    while True:
        log_message(f"Round {round} start")
        t = getUpdatedPT(id, round)
        round += 1
        if (t == "end"):
            log_message("종료 코드 수신")
            s.sendall("끝")
            break


if __name__ == "__main__":
    run()

[49:54] 시작
[49:54] 서버에 접속 중: ('127.0.0.1', 9090)
[49:54] client id : 1
[49:54] Round 1 start
[49:54] file size : 44817908
4096
8192
12288
16384
20480
24576
28672
32768
36864
40960
45056
49152
53248
57344
61440
65536
69632
73728
77824
81920
86016
90112
94208
98304
102400
106496
110592
114688
118784
122880
126976
131072
135168
139264
143360
147456
151552
155648
159744
163840
167936
172032
176128
180224
184320
188416
192512
196608
200704
204800
208896
212992
217088
221184
225280
229376
233472
237568
241664
245760
249856
253952
258048
262144
266240
270336
274432
278528
282624
286720
290816
294912
299008
303104
307200
311296
315392
319488
323584
327680
331776
335872
339968
344064
348160
352256
356352
360448
364544
368640
372736
376832
380928
385024
389120
393216
397312
401408
405504
409600
413696
417792
421888
425984
430080
434176
438272
442368
446464
450560
454656
458752
462848
466944
471040
475136
479232
483328
487424
491520
495616
499712
503808
507904
512000
516096
520192
524288
528384
5

  model = torch.load('./global_model.pt')  # 이 부분 나중에 수정


[50:04] Round 1  Client 1 Training Success
[50:04] End Learning
[50:04] Sending client_model_1.pt to server, size : 44819760bytes
[50:04] client_model_1.pt sent to server
[50:04] Round 2 start
[50:06] file size : 44819636
4096
8192
12288
16384
20480
24576
28672
32768
36864
40960
45056
49152
53248
57344
61440
65536
69632
73728
77824
81920
86016
90112
94208
98304
102400
106496
110592
114688
118784
122880
126976
131072
135168
139264
143360
147456
151552
155648
159744
163840
167936
172032
176128
180224
184320
188416
192512
196608
200704
204800
208896
212992
217088
221184
225280
229376
233472
237568
241664
245760
249856
253952
258048
262144
266240
270336
274432
278528
282624
286720
290816
294912
299008
303104
307200
311296
315392
319488
323584
327680
331776
335872
339968
344064
348160
352256
356352
360448
364544
368640
372736
376832
380928
385024
389120
393216
397312
401408
405504
409600
413696
417792
421888
425984
430080
434176
438272
442368
446464
450560
454656
458752
462848
466944
471040

ConnectionResetError: [WinError 10054] 현재 연결은 원격 호스트에 의해 강제로 끊겼습니다