In [47]:
import torch
import torch.nn as nn
import numpy as np  
import torch.nn.functional as F
import torch.utils.data as data
from torchvision import datasets, transforms
from nets.cnn import CNNCifar
import time

In [67]:
class CNNCifar(nn.Module):
    def __init__(self, model_rate):
        super(CNNCifar, self).__init__()

        pre_hidden_size = [64, 128, 256, 512]
        hidden_size=[int(np.ceil(i*model_rate))  for i in pre_hidden_size]
        self.hidden_size=hidden_size

        self.block1=self._make_block(0)
        self.block2=self._make_block(1)
        self.block3=self._make_block(2)
        self.block4=self._make_block(3)
        self.output=nn.Sequential(
            nn.Linear(hidden_size[-1], 10)
        )
        self.flatten=nn.Flatten(1)
        #self.logit_projector=nn.utils.parametrizations.orthogonal(nn.Linear(10, 10))
        projector_name='orthogonal_projector'
        setattr(self, projector_name, nn.utils.parametrizations.orthogonal(
            nn.Linear(pre_hidden_size[3], int(np.ceil(0.7*pre_hidden_size[3])))))
        projector_name='linear_projector'
        setattr(self, projector_name, nn.Linear(pre_hidden_size[3], int(np.ceil(0.7*pre_hidden_size[3]))))

    def _make_block(self, layer_idx):
        layers=list()
        if(layer_idx == 0):
            layers.append(nn.Conv2d(3, self.hidden_size[0], 3, 1, 1))
        else:
            layers.append(nn.Conv2d(self.hidden_size[layer_idx-1], self.hidden_size[layer_idx], 3, 1, 1))
        layers.append(nn.BatchNorm2d(self.hidden_size[layer_idx], momentum=None, track_running_stats=False))
        layers.append(nn.ReLU(inplace=True))
        if(layer_idx != 3):
            layers.append(nn.MaxPool2d(2))
        return nn.Sequential(*layers)
    
    def forward_feature(self, x):
        out=self.block1(x)
        out=self.block2(out)
        out=self.block3(out)
        out=self.block4(out)
        out=nn.AdaptiveAvgPool2d((1, 1))(out)
        out=self.flatten(out)
        #print(out.shape)
        return out
    
    def forward_head(self, x):
        out=self.output(x)
        return out
    
    def forward(self, x):
        out=self.forward_feature(x)
        #print(out.shape)
        out=self.forward_head(out)
        return out

In [68]:
transform_train=transforms.Compose([  
    transforms.RandomCrop(32, padding=4),  
    transforms.RandomHorizontalFlip(),  
    transforms.ToTensor(),    
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  
])  
transform_test=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

def cifar10_global(batch_size,root):
    dataset_train=datasets.CIFAR10(root, train=True, transform= transform_train, download=True)
    dataset_test=datasets.CIFAR10(root, train=False, transform= transform_test, download=True)
    dataloader_train=data.DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True)
    dataloader_test=data.DataLoader(dataset=dataset_test, batch_size=batch_size, shuffle=False)
    return dataloader_train, dataloader_test

In [73]:
teacher_model=CNNCifar(1.0)
student_model=CNNCifar(1.0)

batch_size=128
dataloader_train_global, dataloader_test_global=cifar10_global(batch_size, root='../../data/cifar10')

Files already downloaded and verified
Files already downloaded and verified


In [75]:
# Define the training function for the teacher model
def test(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

def train_teacher(model, dataloader, epochs, criterion, optimizer, device):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        test_acc = test(model, dataloader_test_global, device)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(dataloader)}, acc: {test_acc}")

# Define the distillation function
def distill(teacher_model, student_model, dataloader, epochs, criterion, optimizer, device, temperature=2.0, alpha=0.5):
    teacher_model.eval()
    student_model.train()
    criterion1 = nn.MSELoss()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            with torch.no_grad():
                teacher_features = teacher_model.forward_feature(inputs)
                teacher_outputs = teacher_model.forward_head(teacher_features) 
            T=2.0
            student_features = student_model.forward_feature(inputs)
            student_outputs = student_model.forward_head(student_features)
            student_features=nn.functional.log_softmax(student_features/T, dim=1)
            teacher_features=nn.functional.softmax(teacher_features/T, dim=1)
            teacher_outputs=nn.functional.softmax(teacher_outputs/T, dim=1)
            student_outputs=nn.functional.log_softmax(student_outputs/T, dim=1)
            loss=(T**2)*criterion1(student_features, teacher_features)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        test_acc = test(student_model, dataloader_test_global, device)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(dataloader)}, acc: {test_acc}")

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move models to device
teacher_model.to(device)
student_model.to(device)

# Define loss criterion and optimizers
criterion1 = nn.CrossEntropyLoss()
criterion2 = nn.KLDivLoss(reduction='batchmean')
teacher_optimizer = torch.optim.Adam(teacher_model.parameters(), lr=0.001)
student_optimizer = torch.optim.Adam(student_model.parameters(), lr=0.001)

# Training and distillation process
communication_rounds = 10
teacher_epochs = 10
distill_epochs = 10

for round in range(communication_rounds):
    print(f"Communication Round {round+1}/{communication_rounds}")
    # Train teacher model
    start=time.time()
    train_teacher(teacher_model, dataloader_train_global, teacher_epochs, criterion1, teacher_optimizer, device)
    print(f"Teacher training time: {time.time()-start}")
    # Distill knowledge to student model
    distill(teacher_model, student_model, dataloader_test_global, distill_epochs, criterion2, student_optimizer, device)
    print(f"Distillation time: {time.time()-start}")
    

Communication Round 1/10
Epoch 1/10, Loss: 0.43763516915728673, acc: 0.8128
Epoch 2/10, Loss: 0.40561784967742004, acc: 0.8283
Epoch 3/10, Loss: 0.3919153481798099, acc: 0.8311
Epoch 4/10, Loss: 0.3706304353978628, acc: 0.8302
Epoch 5/10, Loss: 0.3565370174853698, acc: 0.8332
Epoch 6/10, Loss: 0.33763699915707873, acc: 0.8333
Epoch 7/10, Loss: 0.3179672200356603, acc: 0.8441
Epoch 8/10, Loss: 0.30787154155619006, acc: 0.8418
Epoch 9/10, Loss: 0.29471771278039877, acc: 0.849
Epoch 10/10, Loss: 0.2832119171805394, acc: 0.8381
Teacher training time: 265.8195023536682
Epoch 1/10, Loss: 159.71136358719838, acc: 0.5185
Epoch 2/10, Loss: 158.41102233114123, acc: 0.6337
Epoch 3/10, Loss: 157.74387330646758, acc: 0.6912
Epoch 4/10, Loss: 157.2864949673037, acc: 0.7268
Epoch 5/10, Loss: 156.95645566529865, acc: 0.7632
Epoch 6/10, Loss: 156.74044027207773, acc: 0.7736
Epoch 7/10, Loss: 156.65734245203717, acc: 0.7809
Epoch 8/10, Loss: 156.6123147312599, acc: 0.7873
Epoch 9/10, Loss: 156.521092861

KeyboardInterrupt: 