In [1]:
import torch
import torch.nn as nn
import numpy as np  
import torch.nn.functional as F
import torch.utils.data as data
from torchvision import datasets, transforms
from nets.cnn import CNNCifar
import time

In [4]:
np.random.seed(0)
for i in range(20):
    selected_clients = np.random.choice(200, 10, replace=False)
    print(selected_clients)

[ 18 170 107  98 177 182   5 146  12 152]
[113  50  68 155  57  82  40 105 186 130]
[ 49 136 166 146 198 191 178 156  34  85]
[130 122 113  14  61 103 186 168  60 159]
[115 127 166  27  12   1 108 120  28 121]
[ 53 120 134 113 160 122  35 101 167 148]
[  6 127 198  62 163 153 188 103 190  64]
[195  58 163  44   8  92 191 124 174 156]
[  6  60 130 120  70 165   1 100 155  28]
[ 76  13  38 120 168 114 197 130 156 112]
[ 39 140  85  77  28 163  81  25  19 148]
[ 32 156 149 197 115   1 131  65  24  52]
[  2 118 152  95  20  49 183  40  96 140]
[183  70 101  11  60  51 167 124  50   5]
[175  14  93 139 118  41 131  78  55  75]
[ 52  58 154 114  87  19  53 179  81  26]
[184 117 157  75 121  32 108  73  33 178]
[ 73  76 117  98 154  41  71 140 122 139]
[ 79  69  51  18  20 140 148  94  45  46]
[108  60  56 184 153  96 113  13 197 119]


In [2]:
class CNNCifar(nn.Module):
    def __init__(self, model_rate):
        super(CNNCifar, self).__init__()

        pre_hidden_size = [64, 128, 256, 512]
        hidden_size=[int(np.ceil(i*model_rate))  for i in pre_hidden_size]
        self.hidden_size=hidden_size

        self.block1=self._make_block(0)
        self.block2=self._make_block(1)
        self.block3=self._make_block(2)
        self.block4=self._make_block(3)
        self.output=nn.Sequential(
            nn.Linear(hidden_size[-1], 10)
        )
        self.flatten=nn.Flatten(1)
        #self.logit_projector=nn.utils.parametrizations.orthogonal(nn.Linear(10, 10))
        projector_name='orthogonal_projector'
        setattr(self, projector_name, nn.utils.parametrizations.orthogonal(
            nn.Linear(pre_hidden_size[3], int(np.ceil(0.7*pre_hidden_size[3])))))
        projector_name='linear_projector'
        setattr(self, projector_name, nn.Linear(pre_hidden_size[3], int(np.ceil(0.7*pre_hidden_size[3]))))

    def _make_block(self, layer_idx):
        layers=list()
        if(layer_idx == 0):
            layers.append(nn.Conv2d(3, self.hidden_size[0], 3, 1, 1))
        else:
            layers.append(nn.Conv2d(self.hidden_size[layer_idx-1], self.hidden_size[layer_idx], 3, 1, 1))
        layers.append(nn.BatchNorm2d(self.hidden_size[layer_idx], momentum=None, track_running_stats=False))
        layers.append(nn.ReLU(inplace=True))
        if(layer_idx != 3):
            layers.append(nn.MaxPool2d(2))
        return nn.Sequential(*layers)
    
    def forward_feature(self, x):
        out=self.block1(x)
        out=self.block2(out)
        out=self.block3(out)
        out=self.block4(out)
        out=nn.AdaptiveAvgPool2d((1, 1))(out)
        out=self.flatten(out)
        #print(out.shape)
        return out
    
    def forward_head(self, x):
        out=self.output(x)
        return out
    
    def forward(self, x):
        out=self.forward_feature(x)
        #print(out.shape)
        out=self.forward_head(out)
        return out

In [3]:
transform_train=transforms.Compose([  
    transforms.RandomCrop(32, padding=4),  
    transforms.RandomHorizontalFlip(),  
    transforms.ToTensor(),    
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  
])  
transform_test=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

def cifar10_global(batch_size,root):
    dataset_train=datasets.CIFAR10(root, train=True, transform= transform_train, download=True)
    dataset_test=datasets.CIFAR10(root, train=False, transform= transform_test, download=True)
    dataloader_train=data.DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True)
    dataloader_test=data.DataLoader(dataset=dataset_test, batch_size=batch_size, shuffle=False)
    return dataloader_train, dataloader_test

In [13]:
teacher_model=CNNCifar(0.7)
student_model=CNNCifar(1.0)

batch_size=128
dataloader_train_global, dataloader_test_global=cifar10_global(batch_size, root='../../data/cifar10')

Files already downloaded and verified
Files already downloaded and verified


In [14]:
# Define the training function for the teacher model
def test(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

def train_teacher(model, dataloader, epochs, criterion, optimizer, device):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        test_acc = test(model, dataloader_test_global, device)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(dataloader)}, acc: {test_acc}")

# Define the distillation function
def distill(teacher_model, student_model, dataloader, epochs, criterion, optimizer, device, temperature=2.0, alpha=0.5):
    teacher_model.eval()
    student_model.train()
    criterion1 = nn.MSELoss()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            with torch.no_grad():
                teacher_features = teacher_model.forward_feature(inputs)
                teacher_outputs = teacher_model.forward_head(teacher_features) 
            T=2.0
            student_features = student_model.forward_feature(inputs)
            student_outputs = student_model.forward_head(student_features)
            student_features = student_model.orthogonal_projector(student_features)
            student_features=nn.functional.log_softmax(student_features/T, dim=1)
            teacher_features=nn.functional.softmax(teacher_features/T, dim=1)
            teacher_outputs=nn.functional.softmax(teacher_outputs/T, dim=1)
            student_outputs=nn.functional.log_softmax(student_outputs/T, dim=1)
            loss=(T**2)*criterion1(student_features, teacher_features)+ (T**2)*criterion(student_outputs, teacher_outputs)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        test_acc = test(student_model, dataloader_test_global, device)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(dataloader)}, acc: {test_acc}")

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move models to device
teacher_model.to(device)
student_model.to(device)

# Define loss criterion and optimizers
criterion1 = nn.CrossEntropyLoss()
criterion2 = nn.KLDivLoss(reduction='batchmean')
teacher_optimizer = torch.optim.Adam(teacher_model.parameters(), lr=0.001)
student_optimizer = torch.optim.Adam(student_model.parameters(), lr=0.001)

# Training and distillation process
communication_rounds = 10
teacher_epochs = 10
distill_epochs = 10

for round in range(communication_rounds):
    print(f"Communication Round {round+1}/{communication_rounds}")
    # Train teacher model
    start=time.time()
    train_teacher(teacher_model, dataloader_train_global, teacher_epochs, criterion1, teacher_optimizer, device)
    print(f"Teacher training time: {time.time()-start}")
    # Distill knowledge to student model
    distill(teacher_model, student_model, dataloader_test_global, distill_epochs, criterion2, student_optimizer, device)
    print(f"Distillation time: {time.time()-start}")
    

Communication Round 1/10
Epoch 1/10, Loss: 1.2938785170350233, acc: 0.6189
Epoch 2/10, Loss: 0.9349009609588271, acc: 0.6871
Epoch 3/10, Loss: 0.8005198968950745, acc: 0.712
Epoch 4/10, Loss: 0.7175334631024725, acc: 0.7507
Epoch 5/10, Loss: 0.6621169444087827, acc: 0.7662
Epoch 6/10, Loss: 0.6160329605459862, acc: 0.7723
Epoch 7/10, Loss: 0.5832617718088048, acc: 0.7877
Epoch 8/10, Loss: 0.5479911274617285, acc: 0.7965
Epoch 9/10, Loss: 0.5205248168972142, acc: 0.7898
Epoch 10/10, Loss: 0.49528206301772076, acc: 0.798
Teacher training time: 127.73913908004761
Epoch 1/10, Loss: 141.07748046102404, acc: 0.5244
Epoch 2/10, Loss: 139.95803736433197, acc: 0.6249
Epoch 3/10, Loss: 139.47818746446055, acc: 0.6865
Epoch 4/10, Loss: 139.1932224321969, acc: 0.7205
Epoch 5/10, Loss: 139.0360431912579, acc: 0.7367
Epoch 6/10, Loss: 138.98013189774525, acc: 0.7385
Epoch 7/10, Loss: 138.94964947277987, acc: 0.7414
Epoch 8/10, Loss: 138.91169506990457, acc: 0.7489
Epoch 9/10, Loss: 138.8757760736006

KeyboardInterrupt: 