In [30]:
import torch
import torch.nn as nn
import numpy as np  
import torch.nn.functional as F
import torch.utils.data as data
from torchvision import datasets, transforms
from nets.cnn import CNNCifar
import time

In [31]:
np.random.seed(0)

In [49]:
class CNNCifar(nn.Module):
    def __init__(self) :
        super(CNNCifar,self).__init__()
        self.conv1=nn.Conv2d(3,6,5)
        self.conv2=nn.Conv2d(6,16,5)
        self.fc1=nn.Linear(16*5*5,120)
        self.fc2=nn.Linear(120,84)
        self.fc3=nn.Linear(84,10)
        self.linear = nn.Linear(16, 512)
        self.orthogonal = nn.utils.parametrizations.orthogonal(nn.Linear(16, 512))

    def forward_feature(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        return x
    
    def forward(self,x):
        x=self.conv1(x)
        x=F.relu(x)
        x=F.max_pool2d(x,2)
        x=F.max_pool2d(F.relu(self.conv2(x)),2)
        x=x.view(-1,16*5*5)
        x=F.relu(self.fc1(x))
        x=F.relu(self.fc2(x))
        x=self.fc3(x)
        return x

In [35]:
class Block(nn.Module):
    def __init__(self, in_planes, planes, stride, track):
        super(Block, self).__init__()
        norm1=nn.BatchNorm2d(in_planes, momentum=None, track_running_stats=track)
        norm2=nn.BatchNorm2d(planes, momentum=None, track_running_stats=track)
        self.norm1=norm1
        self.conv1=nn.Conv2d(in_planes, planes, 3, padding=1, stride=stride, bias=False)
        self.norm2=norm2
        self.conv2=nn.Conv2d(planes, planes, 3, padding=1, stride=1, bias=False)
        self.shortcut=nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut=nn.Conv2d(in_planes, planes, 1, stride, bias=False)
            
    def forward(self, x):
        out = F.relu(self.norm1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.norm2(out)))
        out += shortcut
        return out
    
class ResNet(nn.Module):
    def __init__(self, block, model_rate, num_blocks=[2, 2, 2, 2], num_classes=10):
        super(ResNet, self).__init__()
        pre_hidden_size=[64, 128, 256, 512]
        hidden_size=[int(np.ceil(i*model_rate))  for i in pre_hidden_size]
        self.in_planes = hidden_size[0]
        self.conv1=nn.Conv2d(3, hidden_size[0], kernel_size=3, stride=1, padding=1, bias=False)

        self.layer1=self._make_layer(block, hidden_size[0], num_blocks[0], stride=1, track=False)
        self.layer2=self._make_layer(block, hidden_size[1], num_blocks[1], stride=2, track=False)
        self.layer3=self._make_layer(block, hidden_size[2], num_blocks[2], stride=2, track=False)
        self.layer4=self._make_layer(block, hidden_size[3], num_blocks[3], stride=2, track=False)
        self.output=nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(hidden_size[-1], num_classes)
        )
        #self.linear=nn.Linear(hidden_size[-1], num_classes)
            
    def _make_layer(self, block, planes, num_blocks, stride, track):
        strides = [stride] + [1]*(num_blocks-1)
        layers = list()
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, track))
            self.in_planes = planes
        return nn.Sequential(*layers)
    
    def extract_feature(self, x):
        x=self.conv1(x)
        feat1=self.layer1(x)
        feat2=self.layer2(feat1)
        feat3=self.layer3(feat2)
        feat4=self.layer4(feat3)
        out=self.output(feat4)
        feat1=self.layer2[0].norm1(feat1)
        feat2=self.layer3[0].norm1(feat2)
        feat3=self.layer4[0].norm1(feat3)
        return [feat1, feat2, feat3, feat4], out
    
    def forward_feature(self, x):
        out=self.conv1(x)
        out=self.layer1(out)
        out=self.layer2(out)
        out=self.layer3(out)
        out=self.layer4(out)
        #print(out.shape)

        return out
    
    def forward_head(self, x):
        out=self.output(x)
        return out 
    
    def forward(self, x):
        out=self.forward_feature(x)
        out=self.forward_head(out)

        return out
    
def ResNet18Cifar(model_rate):
    model = ResNet(Block, model_rate, num_blocks=[2, 2, 2, 2], num_classes=10)
    return model

def ResNetCifar(model, model_rate):
    if model=='resnet18':
        return ResNet18Cifar(model_rate)

In [36]:
transform_train=transforms.Compose([  
    transforms.RandomCrop(32, padding=4),  
    transforms.RandomHorizontalFlip(),  
    transforms.ToTensor(),    
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  
])  
transform_test=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

def cifar10_global(batch_size,root):
    dataset_train=datasets.CIFAR10(root, train=True, transform= transform_train, download=True)
    dataset_test=datasets.CIFAR10(root, train=False, transform= transform_test, download=True)
    dataloader_train=data.DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True)
    dataloader_test=data.DataLoader(dataset=dataset_test, batch_size=batch_size, shuffle=False)
    return dataloader_train, dataloader_test

In [37]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher_model=ResNetCifar('resnet18',1.0)
student_model=CNNCifar()
teacher_model.to(device)
student_model.to(device)
batch_size=128
dataloader_train_global, dataloader_test_global=cifar10_global(batch_size, root='../../data/cifar10')

Files already downloaded and verified
Files already downloaded and verified


In [38]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for i, (data, target) in enumerate(dataloader_train_global):
    data, target = data.to(device), target.to(device)
    print(data.shape)
    feature=student_model.forward_feature(data)
    print(feature.shape)
    break

torch.Size([128, 3, 32, 32])
torch.Size([128, 16, 5, 5])


In [41]:
def test(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

In [42]:
def train_student(model, dataloader, epochs, criterion, optimizer, device):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        test_acc = test(model, dataloader_test_global, device)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(dataloader)}, acc: {test_acc}")
train_student(student_model, dataloader_train_global, 50, nn.CrossEntropyLoss(), torch.optim.Adam(student_model.parameters(), lr=0.001), device)

Epoch 1/50, Loss: 1.8175395471055795, acc: 0.4283
Epoch 2/50, Loss: 1.5708941858442849, acc: 0.4832
Epoch 3/50, Loss: 1.4593743524892862, acc: 0.5164
Epoch 4/50, Loss: 1.396874821704367, acc: 0.527
Epoch 5/50, Loss: 1.3403127821510101, acc: 0.5508
Epoch 6/50, Loss: 1.3004061933368674, acc: 0.5775
Epoch 7/50, Loss: 1.2585096715966149, acc: 0.5797
Epoch 8/50, Loss: 1.2262680600671207, acc: 0.6026
Epoch 9/50, Loss: 1.201399844168397, acc: 0.5917
Epoch 10/50, Loss: 1.1774942980093115, acc: 0.6002
Epoch 11/50, Loss: 1.159031412638057, acc: 0.608
Epoch 12/50, Loss: 1.148664189574054, acc: 0.6138
Epoch 13/50, Loss: 1.1229780852947089, acc: 0.6237
Epoch 14/50, Loss: 1.1104651183423484, acc: 0.6314
Epoch 15/50, Loss: 1.0990773965330685, acc: 0.6414
Epoch 16/50, Loss: 1.0854847303131963, acc: 0.6241
Epoch 17/50, Loss: 1.0819178382149133, acc: 0.6466
Epoch 18/50, Loss: 1.0638349330638681, acc: 0.6431
Epoch 19/50, Loss: 1.055958756095613, acc: 0.6393
Epoch 20/50, Loss: 1.0473766861974125, acc: 0.6

In [43]:
def train_teacher(model, dataloader, epochs, criterion, optimizer, device):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        test_acc = test(model, dataloader_test_global, device)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(dataloader)}, acc: {test_acc}")

train_teacher(teacher_model, dataloader_train_global, 50, nn.CrossEntropyLoss(), torch.optim.Adam(teacher_model.parameters(), lr=0.001), device)

Epoch 1/50, Loss: 1.7737682713267138, acc: 0.4409
Epoch 2/50, Loss: 1.222518632174148, acc: 0.569
Epoch 3/50, Loss: 0.9864564073055296, acc: 0.6285
Epoch 4/50, Loss: 0.8406389271816634, acc: 0.6828
Epoch 5/50, Loss: 0.7154518311743236, acc: 0.7263
Epoch 6/50, Loss: 0.6263538815481279, acc: 0.7624
Epoch 7/50, Loss: 0.5537493228912354, acc: 0.7834
Epoch 8/50, Loss: 0.5022377313860237, acc: 0.8096
Epoch 9/50, Loss: 0.454133230135264, acc: 0.8129
Epoch 10/50, Loss: 0.41729632820314766, acc: 0.8224
Epoch 11/50, Loss: 0.38236659349840313, acc: 0.8264
Epoch 12/50, Loss: 0.3600789384006539, acc: 0.8325
Epoch 13/50, Loss: 0.32946471183958564, acc: 0.8457
Epoch 14/50, Loss: 0.30436257850331117, acc: 0.8441
Epoch 15/50, Loss: 0.2860005771755562, acc: 0.8508
Epoch 16/50, Loss: 0.2703135870492367, acc: 0.8614
Epoch 17/50, Loss: 0.24511759587184853, acc: 0.8658
Epoch 18/50, Loss: 0.24065172169214624, acc: 0.8657
Epoch 19/50, Loss: 0.22084292727510643, acc: 0.857
Epoch 20/50, Loss: 0.2066577452870890

In [50]:
def logit_distill(teacher_model, student_model, dataloader, epochs, criterion, optimizer, device, temperature=2.0, alpha=0.5):
    teacher_model.eval()
    student_model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            with torch.no_grad():
                teacher_outputs = teacher_model(inputs) 
            T=2.0
            student_outputs = student_model(inputs)
            teacher_outputs=nn.functional.softmax(teacher_outputs/T, dim=1)
            student_outputs=nn.functional.log_softmax(student_outputs/T, dim=1)
            loss=(T**2)*criterion(student_outputs, teacher_outputs)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        test_acc = test(student_model, dataloader_test_global, device)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(dataloader)}, acc: {test_acc}")

def linear_feature_logit(teacher_model, student_model, dataloader, epochs, criterion, optimizer, device, temperature=2.0, alpha=0.5):
    teacher_model.eval()
    student_model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            with torch.no_grad():
                teacher_feature = teacher_model.forward_feature(inputs) 
                b, c, h, w=teacher_feature.shape
                teacher_feature=teacher_feature.view(b, c, h*w).mean(-1)
                teacher_outputs = teacher_model(inputs)
            T=2.0
            student_feature = student_model.forward_feature(inputs)
            b, c, h, w=student_feature.shape
            student_feature=student_feature.view(b, c, h*w).mean(-1)
            student_feature=student_model.linear(student_feature)
            student_outputs=student_model(inputs)
            teacher_feature=nn.functional.softmax(teacher_feature/T, dim=1)
            student_feature=nn.functional.log_softmax(student_feature/T, dim=1)
            teacher_outputs=nn.functional.softmax(teacher_outputs/T, dim=1)
            student_outputs=nn.functional.log_softmax(student_outputs/T, dim=1)
            loss=(T**2)*criterion(student_outputs, teacher_outputs) + (T**2)*criterion(student_feature, teacher_feature)/2
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        test_acc = test(student_model, dataloader_test_global, device)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(dataloader)}, acc: {test_acc}")

def orthogonal_feature_logit(teacher_model, student_model, dataloader, epochs, criterion, optimizer, device, temperature=2.0, alpha=0.5):
    teacher_model.eval()
    student_model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            with torch.no_grad():
                teacher_feature = teacher_model.forward_feature(inputs) 
                b, c, h, w=teacher_feature.shape
                teacher_feature=teacher_feature.view(b, c, h*w).mean(-1)
                teacher_outputs = teacher_model(inputs)
            T=2.0
            student_feature = student_model.forward_feature(inputs)
            b, c, h, w=student_feature.shape
            student_feature=student_feature.view(b, c, h*w).mean(-1)
            student_feature=student_model.orthogonal(student_feature)
            student_outputs=student_model(inputs)
            teacher_feature=nn.functional.softmax(teacher_feature/T, dim=1)
            student_feature=nn.functional.log_softmax(student_feature/T, dim=1)
            teacher_outputs=nn.functional.softmax(teacher_outputs/T, dim=1)
            student_outputs=nn.functional.log_softmax(student_outputs/T, dim=1)
            loss=(T**2)*criterion(student_outputs, teacher_outputs) + (T**2)*criterion(student_feature, teacher_feature)/2
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        test_acc = test(student_model, dataloader_test_global, device)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(dataloader)}, acc: {test_acc}")




In [47]:
#student_model=CNNCifar().to(device)
logit_distill(teacher_model, student_model, dataloader_test_global, 50, nn.KLDivLoss(reduction='batchmean'), torch.optim.Adam(student_model.parameters(), lr=0.001), device, temperature=2.0, alpha=0.5)

Epoch 1/50, Loss: 1.805564669868614, acc: 0.7637
Epoch 2/50, Loss: 1.7319887687888327, acc: 0.7616
Epoch 3/50, Loss: 1.6939975513687617, acc: 0.758
Epoch 4/50, Loss: 1.6815914877607852, acc: 0.7617
Epoch 5/50, Loss: 1.666155357904072, acc: 0.7641
Epoch 6/50, Loss: 1.6313270258752606, acc: 0.7701
Epoch 7/50, Loss: 1.5688068174485919, acc: 0.7685
Epoch 8/50, Loss: 1.5256271558471872, acc: 0.7733
Epoch 9/50, Loss: 1.5056571226708497, acc: 0.7805
Epoch 10/50, Loss: 1.4792810444590412, acc: 0.7839
Epoch 11/50, Loss: 1.4928666034454032, acc: 0.7734
Epoch 12/50, Loss: 1.5759678496212899, acc: 0.787
Epoch 13/50, Loss: 1.6601371327532997, acc: 0.7991
Epoch 14/50, Loss: 1.642585719499407, acc: 0.7991
Epoch 15/50, Loss: 1.601826298085949, acc: 0.8035
Epoch 16/50, Loss: 1.5294582615547543, acc: 0.8153
Epoch 17/50, Loss: 1.4637841989722433, acc: 0.7976
Epoch 18/50, Loss: 1.4199930374758154, acc: 0.803
Epoch 19/50, Loss: 1.352782128355171, acc: 0.8207
Epoch 20/50, Loss: 1.2661211915031265, acc: 0.82

In [54]:
#student_model=CNNCifar().to(device)
orthogonal_feature_logit(teacher_model, student_model, dataloader_test_global, 20, nn.KLDivLoss(reduction='batchmean'), torch.optim.Adam(student_model.parameters(), lr=0.001), device, temperature=2.0, alpha=0.5)

Epoch 1/20, Loss: 0.6491282842581785, acc: 0.8586
Epoch 2/20, Loss: 0.6230390667915344, acc: 0.8674
Epoch 3/20, Loss: 0.6303821473936492, acc: 0.8564
Epoch 4/20, Loss: 0.6190885980672474, acc: 0.8502
Epoch 5/20, Loss: 0.5959386535083191, acc: 0.8494
Epoch 6/20, Loss: 0.6252081926110424, acc: 0.8554
Epoch 7/20, Loss: 0.638965550102765, acc: 0.8555
Epoch 8/20, Loss: 0.6946149558960637, acc: 0.8477
Epoch 9/20, Loss: 0.7039364338675632, acc: 0.8485
Epoch 10/20, Loss: 0.6947954512095149, acc: 0.8615
Epoch 11/20, Loss: 0.6857264796389809, acc: 0.8557
Epoch 12/20, Loss: 0.6562785517565811, acc: 0.7911
Epoch 13/20, Loss: 0.6995280151125751, acc: 0.8321
Epoch 14/20, Loss: 0.7234813631335392, acc: 0.8109
Epoch 15/20, Loss: 0.7122362146649179, acc: 0.7994
Epoch 16/20, Loss: 0.6693531018269213, acc: 0.8006
Epoch 17/20, Loss: 0.6834142453308347, acc: 0.8427
Epoch 18/20, Loss: 0.6744329091868823, acc: 0.86
Epoch 19/20, Loss: 0.6850139351585244, acc: 0.8511
Epoch 20/20, Loss: 0.6588590205470218, acc:

In [52]:
#student_model=CNNCifar().to(device )
linear_feature_logit(teacher_model, student_model, dataloader_test_global, 50, nn.KLDivLoss(reduction='batchmean'), torch.optim.Adam(student_model.parameters(), lr=0.001), device, temperature=2.0, alpha=0.5)

Epoch 1/50, Loss: 1.6040977605536013, acc: 0.7746
Epoch 2/50, Loss: 1.4876820956227146, acc: 0.781
Epoch 3/50, Loss: 1.4266639442383489, acc: 0.7866
Epoch 4/50, Loss: 1.3946480041817775, acc: 0.791
Epoch 5/50, Loss: 1.3790809270701831, acc: 0.7876
Epoch 6/50, Loss: 1.4334243166295788, acc: 0.7434
Epoch 7/50, Loss: 1.5278609402572052, acc: 0.7387
Epoch 8/50, Loss: 1.5920188702359985, acc: 0.7475
Epoch 9/50, Loss: 1.5755640223056455, acc: 0.7292
Epoch 10/50, Loss: 1.561286427552187, acc: 0.6944
Epoch 11/50, Loss: 1.5968164363993873, acc: 0.747
Epoch 12/50, Loss: 1.623670150962057, acc: 0.7781
Epoch 13/50, Loss: 1.5176632019537915, acc: 0.7845
Epoch 14/50, Loss: 1.4688349015350584, acc: 0.8032
Epoch 15/50, Loss: 1.4361648619929446, acc: 0.8126
Epoch 16/50, Loss: 1.3642199314847778, acc: 0.8145
Epoch 17/50, Loss: 1.283259337838692, acc: 0.8125
Epoch 18/50, Loss: 1.2238493595696702, acc: 0.8215
Epoch 19/50, Loss: 1.1642638724061507, acc: 0.8254
Epoch 20/50, Loss: 1.1402955296673352, acc: 0.

In [14]:
# Define the training function for the teacher model
def test(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

def train_teacher(model, dataloader, epochs, criterion, optimizer, device):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        test_acc = test(model, dataloader_test_global, device)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(dataloader)}, acc: {test_acc}")

# Define the distillation function
def distill(teacher_model, student_model, dataloader, epochs, criterion, optimizer, device, temperature=2.0, alpha=0.5):
    teacher_model.eval()
    student_model.train()
    criterion1 = nn.MSELoss()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            with torch.no_grad():
                teacher_features = teacher_model.forward_feature(inputs)
                teacher_outputs = teacher_model.forward_head(teacher_features) 
            T=2.0
            student_features = student_model.forward_feature(inputs)
            student_outputs = student_model.forward_head(student_features)
            student_features = student_model.orthogonal_projector(student_features)
            student_features=nn.functional.log_softmax(student_features/T, dim=1)
            teacher_features=nn.functional.softmax(teacher_features/T, dim=1)
            teacher_outputs=nn.functional.softmax(teacher_outputs/T, dim=1)
            student_outputs=nn.functional.log_softmax(student_outputs/T, dim=1)
            loss=(T**2)*criterion1(student_features, teacher_features)+ (T**2)*criterion(student_outputs, teacher_outputs)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        test_acc = test(student_model, dataloader_test_global, device)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(dataloader)}, acc: {test_acc}")

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move models to device
teacher_model.to(device)
student_model.to(device)

# Define loss criterion and optimizers
criterion1 = nn.CrossEntropyLoss()
criterion2 = nn.KLDivLoss(reduction='batchmean')
teacher_optimizer = torch.optim.Adam(teacher_model.parameters(), lr=0.001)
student_optimizer = torch.optim.Adam(student_model.parameters(), lr=0.001)

# Training and distillation process
communication_rounds = 10
teacher_epochs = 10
distill_epochs = 10

for round in range(communication_rounds):
    print(f"Communication Round {round+1}/{communication_rounds}")
    # Train teacher model
    start=time.time()
    train_teacher(teacher_model, dataloader_train_global, teacher_epochs, criterion1, teacher_optimizer, device)
    print(f"Teacher training time: {time.time()-start}")
    # Distill knowledge to student model
    distill(teacher_model, student_model, dataloader_test_global, distill_epochs, criterion2, student_optimizer, device)
    print(f"Distillation time: {time.time()-start}")
    

Communication Round 1/10
Epoch 1/10, Loss: 1.2938785170350233, acc: 0.6189
Epoch 2/10, Loss: 0.9349009609588271, acc: 0.6871
Epoch 3/10, Loss: 0.8005198968950745, acc: 0.712
Epoch 4/10, Loss: 0.7175334631024725, acc: 0.7507
Epoch 5/10, Loss: 0.6621169444087827, acc: 0.7662
Epoch 6/10, Loss: 0.6160329605459862, acc: 0.7723
Epoch 7/10, Loss: 0.5832617718088048, acc: 0.7877
Epoch 8/10, Loss: 0.5479911274617285, acc: 0.7965
Epoch 9/10, Loss: 0.5205248168972142, acc: 0.7898
Epoch 10/10, Loss: 0.49528206301772076, acc: 0.798
Teacher training time: 127.73913908004761
Epoch 1/10, Loss: 141.07748046102404, acc: 0.5244
Epoch 2/10, Loss: 139.95803736433197, acc: 0.6249
Epoch 3/10, Loss: 139.47818746446055, acc: 0.6865
Epoch 4/10, Loss: 139.1932224321969, acc: 0.7205
Epoch 5/10, Loss: 139.0360431912579, acc: 0.7367
Epoch 6/10, Loss: 138.98013189774525, acc: 0.7385
Epoch 7/10, Loss: 138.94964947277987, acc: 0.7414
Epoch 8/10, Loss: 138.91169506990457, acc: 0.7489
Epoch 9/10, Loss: 138.8757760736006

KeyboardInterrupt: 