In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torchvision import models
from torchvision.models.resnet import resnet34
from torchvision.transforms.transforms import Resize
import torch.nn.functional as F
import json
import numpy as np
from PIL import Image

if torch.cuda.is_available():
        device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

transform = transforms.Compose([
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
])
#采用自带的Cifar100
trainset = torchvision.datasets.CIFAR100(root='./data_CIFAR100', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
 
testset = torchvision.datasets.CIFAR100(root='./data_CIFAR100', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)


### 在数据集中加入weight

In [5]:
from torchvision import datasets
class CIFAR100(datasets.CIFAR100):
    def __init__(self, root, indexs, influence_weight, train=True,
                 transform=torchvision.transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]),
                 download=True):
        super().__init__(root, train=train,
                         transform=transform,
                         download=download)
        if indexs is not None:
            self.data = self.data[indexs]
            self.targets = np.array(self.targets)[indexs]
            self.index = indexs
            self.weight = influence_weight[indexs]


    def __getitem__(self, index):
        img, targets,indexs,weight = self.data[index], \
                                    self.targets[index], \
                                    self.index[index],\
                                    self.weight[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        return img, targets,indexs,weight

In [6]:
all_idx = np.array(range(len(trainset.targets)))
default_weight = np.ones(len(trainset.targets))

train_dataset = CIFAR100(root='./data_CIFAR100',
     indexs = all_idx,influence_weight=default_weight, train=True
)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=False)


Files already downloaded and verified


### 训练 Teacher Model

In [11]:
teacher_model=models.resnet50(pretrained=True)
    
fc_inputs = teacher_model.fc.in_features #获得fc特征层的输入
teacher_model.fc = nn.Sequential(         #重新定义特征层，根据需要可以添加自己想要的Linear层
    nn.Linear(fc_inputs, 100),  #多加几层都没关系 
)

In [12]:
 
 # Training
def train_withif(epoch,batch_size):
    print('\nTrain: Epoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets,indexes,weights) in enumerate(train_dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        weights = weights.to(device)
        optimizer.zero_grad()
        outputs = net(torch.squeeze(inputs, 1))
        loss = torch.sum(criterion_train(outputs, targets)*weights)/batch_size
        loss.backward()
        optimizer.step()
 
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx%100 ==0:
            print('epoch: %d' % epoch, '| Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    # wandb.log({
    #     "Train_Loss": train_loss,
    #     "Train_Acc": 100. * correct / total
    # })
    
def test(epoch):
    print('\nTest: Epoch: %d' % epoch)
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(torch.squeeze(inputs, 1))
            loss = criterion_test(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    print('epoch: %d'% epoch, 'Acc: %.3f%% (%d/%d)'
        %  (100.*correct/total, correct, total))
    # wandb.log({
    #     "Test_Acc": 100. * correct / total
    # })

    

In [13]:
net = teacher_model.to(device)
criterion_train = nn.CrossEntropyLoss(reduction='none')
criterion_test = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9) #减小 lr

epoches = 5 #for example

for epoch in range(epoches):
    train_withif(epoch,batch_size=64)
    test(epoch)




Train: Epoch: 0
epoch: 0 | Loss: 4.609 | Acc: 1.562% (1/64)
epoch: 0 | Loss: 2.805 | Acc: 33.029% (2135/6464)
epoch: 0 | Loss: 2.104 | Acc: 46.261% (5951/12864)
epoch: 0 | Loss: 1.839 | Acc: 51.557% (9932/19264)
epoch: 0 | Loss: 1.681 | Acc: 54.941% (14100/25664)
epoch: 0 | Loss: 1.581 | Acc: 57.117% (18314/32064)
epoch: 0 | Loss: 1.508 | Acc: 58.707% (22581/38464)
epoch: 0 | Loss: 1.452 | Acc: 59.912% (26879/44864)

Test: Epoch: 0
epoch: 0 Acc: 67.130% (6713/10000)

Train: Epoch: 1
epoch: 1 | Loss: 0.991 | Acc: 68.750% (44/64)
epoch: 1 | Loss: 0.954 | Acc: 71.674% (4633/6464)
epoch: 1 | Loss: 0.887 | Acc: 73.212% (9418/12864)
epoch: 1 | Loss: 0.855 | Acc: 73.967% (14249/19264)
epoch: 1 | Loss: 0.829 | Acc: 74.719% (19176/25664)
epoch: 1 | Loss: 0.810 | Acc: 75.299% (24144/32064)
epoch: 1 | Loss: 0.796 | Acc: 75.764% (29142/38464)
epoch: 1 | Loss: 0.787 | Acc: 76.052% (34120/44864)

Test: Epoch: 1
epoch: 1 Acc: 72.650% (7265/10000)

Train: Epoch: 2
epoch: 2 | Loss: 0.633 | Acc: 76.562

In [14]:
teacher_save_path='./resnet50_cifar100_epoch5_withoutif.pkl'
torch.save(net.state_dict(),teacher_save_path) 

### 训练 Student Model

In [16]:
def TS_train_1_withif_2(epoch,teacher_model,student_model,trainloader,soft_loss,optimizer,batch_size):
    print('\nTrain: Epoch: %d' % epoch)
    student_model.train()
    train_loss = 0
    correct = 0
    total = 0
    temp = 7
    alpha = 0.3
    
    for batch_idx, (inputs, targets,indexes,weights) in enumerate(trainloader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        weights = weights.to(device)

        with torch.no_grad():
            teacher_preds = teacher_model(torch.squeeze(inputs, 1))

        # 学生模型预测
        student_preds = student_model(torch.squeeze(inputs, 1))

        # 计算蒸馏后的预测结果及soft_loss
        distillation_loss = soft_loss(
            F.log_softmax(student_preds/temp, dim=1),#log_softmax
            F.softmax(teacher_preds/temp, dim=1)
        )
        #loss = distillation_loss*weights
        loss = torch.sum(torch.sum(distillation_loss,dim=1)*weights)/batch_size
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
        train_loss += loss.item()
        _, predicted = student_preds.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx%200 ==0:
            print(batch_idx+1,'/', len(trainloader),'epoch: %d' % epoch, '| Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # wandb.log({
    #     "Train_Loss": train_loss,
    #     "Train_Acc": 100. * correct / total
    # })

def test(epoch):
    print('\nTest: Epoch: %d' % epoch)
    global best_acc
    student_model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = student_model(torch.squeeze(inputs, 1))
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    print('epoch: %d'% epoch, 'Acc: %.3f%% (%d/%d)'
        %  (100.*correct/total, correct, total))
    # wandb.log({
    #     "Test_Acc": 100. * correct / total
    # })



In [40]:
teacher_model=models.resnet50(pretrained=True)
    
fc_inputs = teacher_model.fc.in_features #获得fc特征层的输入
teacher_model.fc = nn.Sequential(         #重新定义特征层，根据需要可以添加自己想要的Linear层
    nn.Linear(fc_inputs, 100),  #多加几层都没关系 
)

teacher_model.load_state_dict(torch.load(teacher_save_path)) 
teacher_model = teacher_model.to(device)

###############################################################

student_model=models.mobilenet_v2(pretrained=True)

student_model.classifier = nn.Sequential( 
    #重新定义特征层，根据需要可以添加自己想要的Linear层
    nn.Dropout(p=0.2, inplace=False),
    nn.Linear(in_features=1280, out_features=100),  #多加几层都没关系
    #nn.LogSoftmax(dim=1)
)
 
student_model = student_model.to(device)

In [41]:
criterion = nn.CrossEntropyLoss()
soft_loss = nn.KLDivLoss(reduction="none")
optimizer = optim.SGD(student_model.parameters(), lr=0.01,momentum=0.9)
batch_size=64
temp = 7

epoches=10 #for example

for epoch in range(epoches): 
    TS_train_1_withif_2(epoch,teacher_model,student_model,train_dataloader,soft_loss,optimizer,batch_size)
    test(epoch)





Train: Epoch: 0
1 / 782 epoch: 0 | Loss: 0.234 | Acc: 1.562% (1/64)
201 / 782 epoch: 0 | Loss: 0.234 | Acc: 2.651% (341/12864)
401 / 782 epoch: 0 | Loss: 0.221 | Acc: 4.095% (1051/25664)
601 / 782 epoch: 0 | Loss: 0.208 | Acc: 5.363% (2063/38464)

Test: Epoch: 0
epoch: 0 Acc: 13.160% (1316/10000)

Train: Epoch: 1
1 / 782 epoch: 1 | Loss: 0.153 | Acc: 10.938% (7/64)
201 / 782 epoch: 1 | Loss: 0.146 | Acc: 15.236% (1960/12864)
401 / 782 epoch: 1 | Loss: 0.139 | Acc: 16.852% (4325/25664)
601 / 782 epoch: 1 | Loss: 0.134 | Acc: 18.500% (7116/38464)

Test: Epoch: 1
epoch: 1 Acc: 26.680% (2668/10000)

Train: Epoch: 2
1 / 782 epoch: 2 | Loss: 0.110 | Acc: 23.438% (15/64)
201 / 782 epoch: 2 | Loss: 0.107 | Acc: 27.806% (3577/12864)
401 / 782 epoch: 2 | Loss: 0.104 | Acc: 28.538% (7324/25664)
601 / 782 epoch: 2 | Loss: 0.101 | Acc: 29.529% (11358/38464)

Test: Epoch: 2
epoch: 2 Acc: 34.650% (3465/10000)

Train: Epoch: 3
1 / 782 epoch: 3 | Loss: 0.087 | Acc: 35.938% (23/64)
201 / 782 epoch: 3 |

In [42]:
student_save_path='./Dis_resnet50(T)_mobilebetv2(S)_cifar100_epoch10_withif_1.pkl'
torch.save(student_model.state_dict(),student_save_path) 

In [None]:
# epoches=50

# for epoch in range(epoches): 
#     TS_train_1_withif_2(epoch,teacher_model,student_model,train_dataloader,soft_loss,optimizer,batch_size)
#     test(epoch)
#     if epoch==30:
#         save_path='./save_model_202312/Dis_resnet50(T)_vgg(S)_cifar100_epoch30_withif_1.pkl'
#         torch.save(student_model.state_dict(),save_path) 

# save_path='./save_model_202312/Dis_resnet50(T)_vgg(S)_cifar100_epoch50_withif_1.pkl'
# torch.save(student_model.state_dict(),save_path) 
