In [239]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader

from torchvision.models import resnet50
import torchvision.models as models
# from torchsummary import summary
import torchinfo as summary
import torch.nn.functional as F
from tqdm.auto import tqdm
from pathlib import Path
device = "cuda" if torch.cuda.is_available() else "cpu"
class depthwise_separable_conv(nn.Module):
    def __init__(self, nin, nout,kernel_size =3 , stride=1, padding=0): 
        super(depthwise_separable_conv, self).__init__() 
        self.depthwise = nn.Conv2d(nin, nin, kernel_size, stride = stride, padding = padding, groups=nin) 
        self.pointwise = nn.Conv2d(nin, nout, 1) 

    def forward(self, x): 
        out = self.depthwise(x) 
        out = self.pointwise(out) 
        return out
class StudentModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(
            depthwise_separable_conv(3,64,padding="same"),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            depthwise_separable_conv(64,64),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Dropout(0.2),
            
            depthwise_separable_conv(64,128,padding="same"),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            depthwise_separable_conv(128,128),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Dropout(0.2),
            
            depthwise_separable_conv(128,256,padding="same"),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            depthwise_separable_conv(256,256),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Dropout(0.2),
            
            #  depthwise_separable_conv(64,256,padding="same"),
            # nn.BatchNorm2d(256),
            # nn.ReLU(),
            # depthwise_separable_conv(256,256,padding="same"),
            # nn.BatchNorm2d(256),
            # nn.ReLU(),
            depthwise_separable_conv(256,256,padding="same"),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            depthwise_separable_conv(256,512,padding="same"),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            
            nn.AdaptiveAvgPool2d((1,1)), # [128,1,1]
            nn.Flatten()
        )
        self.out = nn.Sequential(
            nn.Linear(512,256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256,10)
        )

    def forward(self, x):
        x = self.cnn(x)
        x = self.out(x)
        return x



In [240]:
def studentLossFn(teacher_pred, student_pred, y, T=3, alpha=0.4):
    if (alpha > 0):
        loss = F.kl_div(F.log_softmax(student_pred / T, dim=1), F.softmax(teacher_pred / T, dim=1), reduction='batchmean') * (T ** 2) * alpha + F.cross_entropy(student_pred, y) * (1 - alpha)
    else:
        loss = F.cross_entropy(student_pred, y)
    return loss

In [242]:
class ResNet(nn.Module):
        def __init__(self):
            super(ResNet, self).__init__()
            self.resnet50 = resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
            num_ftrs = self.resnet50.fc.in_features
            self.resnet50.fc = nn.Linear(num_ftrs, 10)

        def forward(self, x):
            x = self.resnet50.conv1(x)
            x = self.resnet50.bn1(x)
            x = self.resnet50.relu(x)
            x = self.resnet50.maxpool(x)

            x = self.resnet50.layer1(x)
            x = self.resnet50.layer2(x)
            x = self.resnet50.layer3(x)
            x = self.resnet50.layer4(x)

            x = self.resnet50.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.resnet50.fc(x)
            return x

teacher_model = ResNet()
weights_path = './resnet-50.pth'
checkpoint = torch.load(weights_path)
teacher_model.load_state_dict(checkpoint['model_state_dict'])
teacher_model = teacher_model.to(device)

In [243]:
train_transform = transforms.Compose(
        [transforms.Grayscale(num_output_channels=3),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))])
test_transform = transforms.Compose(
        [transforms.Grayscale(num_output_channels=3),  # gray to 3 channel
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.FashionMNIST(root='./data', train=True,
                                                download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                            shuffle=True, num_workers=8)

testset = torchvision.datasets.FashionMNIST(root='./data', train=False,
                                            download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1,
                                            shuffle=False, num_workers=0)

In [244]:
def train_step(student_model,teacher_model, train_dataloader,train_loss_function, opt, device):
    train_acc = 0
    train_loss = 0
    student_model.train()
    for x,y in train_dataloader:
        x,y = x.to(device), y.to(device)
        student_pred = student_model(x)
        with torch.no_grad():
            teacher_pred = teacher_model(x)
        loss = train_loss_function(teacher_pred,student_pred, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        train_loss += loss.item()
        train_acc += (torch.max(student_pred,1)[1]==y).sum().item()/len(student_pred)
    return train_loss/len(train_dataloader), train_acc/len(train_dataloader)

def test_step(student_model, test_dataloader, test_loss_function, device): 
    test_acc = 0
    test_loss = 0
    student_model.eval()
    with torch.inference_mode():
        for x,y in test_dataloader:
            x,y = x.to(device), y.to(device)
            pred_y = student_model(x)
            # print("pred_y:",pred_y.shape)
            loss = test_loss_function(pred_y,y) 
            test_loss += loss.item()
            test_acc += (torch.argmax(pred_y,1)==y).sum().item()/len(pred_y)
    return test_loss/len(test_dataloader), test_acc/len(test_dataloader)
def train(epochs, student_model, teacher_model,
          train_dataloader, test_dataloader, opt, 
          train_loss_function,test_loss_function,
          device, patience,model_name):
    last_loss = float("inf")
    cur = 0
    results ={
        "train_loss":[],
        "train_acc":[],
        "test_loss":[],
        "test_acc":[]
    }
    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(student_model=student_model,
                                           teacher_model = teacher_model,
                                          train_dataloader=train_dataloader,
                                          train_loss_function=train_loss_function,
                                          opt=opt,
                                          device=device)
        test_loss, test_acc = test_step(student_model=student_model,
                                          test_dataloader=test_dataloader,
                                          test_loss_function=test_loss_function,
                                          device=device)
        if test_loss > last_loss:
            cur += 1
            print('trigger times:', cur)
            if cur >= patience:
                print("early stop !")
                return results
        else:
            cur = 0
        last_loss = test_loss
        print(
          f"Epoch: {epoch+1} | "
          f"train_loss: {train_loss:.4f} | "
          f"train_acc: {train_acc:.4f} | "
          f"test_loss: {test_loss:.4f} | "
          f"test_acc: {test_acc:.4f}"
        )

      # Update results dictionary
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)
        if (epoch+1)%10 == 0:
            MODEL_PATH = Path(model_name)
            MODEL_PATH.mkdir(parents=True, # create parent directories if needed
                             exist_ok=True # if models directory already exists, don't error
            )

            # Create model save path
            MODEL_NAME = f"model_{epoch+1}.pth"
            MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

            # Save the model state dict
            print(f"Saving model to: {MODEL_SAVE_PATH}")
            torch.save(obj=student_model.state_dict(), # only saving the state_dict() only saves the learned parameters
                       f=MODEL_SAVE_PATH)
    return results

# Prune model (假設已經train好準備要pruned model)

In [245]:
import torch.nn.utils.prune as prune
import torchinfo as summary

In [246]:
pruned_net = StudentModel()
checkpoint = torch.load("model_to_prune.pth")
pruned_net.load_state_dict(checkpoint)

<All keys matched successfully>

In [247]:
# Get the module to prune
param_to_prune=[]
for i,(name, module) in enumerate(pruned_net.named_modules()):
    if i>=2:
        if isinstance(module,depthwise_separable_conv):
            param_to_prune.append((module.depthwise,'weight'))
            param_to_prune.append((module.pointwise,'weight'))
        elif isinstance(module,nn.Conv2d)|isinstance(module,nn.Linear)|isinstance(module,nn.BatchNorm2d):
            param_to_prune.append((module,'weight'))
            

In [248]:
prune.global_unstructured(
    param_to_prune,
     pruning_method = prune.L1Unstructured,
     amount = 0.8
)

### _forward_pre_hooks 確保 model在做 forward前會mask 掉pruning weight

In [237]:
for p in param_to_prune:
    print(p[0],p[0]._forward_pre_hooks)
    print("-"*10)

Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=same, groups=3) OrderedDict([(570, <torch.nn.utils.prune.PruningContainer object at 0x7efa4d7f9700>)])
----------
Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1)) OrderedDict([(571, <torch.nn.utils.prune.PruningContainer object at 0x7efa4d83b730>)])
----------
Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=same, groups=3) OrderedDict([(570, <torch.nn.utils.prune.PruningContainer object at 0x7efa4d7f9700>)])
----------
Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1)) OrderedDict([(571, <torch.nn.utils.prune.PruningContainer object at 0x7efa4d83b730>)])
----------
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) OrderedDict([(572, <torch.nn.utils.prune.CustomFromMask object at 0x7efa4d83b790>)])
----------
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), groups=64) OrderedDict([(575, <torch.nn.utils.prune.PruningContainer object at 0x7efa4d84cfa0>)])
----------
Conv2d(64, 64, kernel_size=(1, 

### Retrain pruned model

In [None]:
EPOCH = 400
LR = 1e-4

In [None]:
opt = torch.optim.AdamW(pruned_net.parameters(),lr = LR )
pruned_net.to(device)
results = train(EPOCH,
      pruned_net,teacher_model,
      trainloader,testloader,
      train_loss_function = studentLossFn,
      test_loss_function = nn.CrossEntropyLoss(),
      opt = opt,
      device=device,
      patience = 5,
      model_name = "pruned_student_model")

### Testing retrained model accuracy

In [249]:
checkpoint = torch.load("pruned_student_model/model_170.pth")
pruned_net.load_state_dict(checkpoint)
pruned_net.to(device)
pruned_net.eval()
correct = 0
total = 0
pred_arr = []
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = pruned_net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        pred_arr.append(predicted.item())
accuracy = 100 * correct / total
print(f"Accuracy of the network on the {total} test images: {accuracy:.2f} %")

Accuracy of the network on the 10000 test images: 93.47 %


In [250]:
summary.summary(pruned_net.to(device))

Layer (type:depth-idx)                        Param #
StudentModel                                  --
├─Sequential: 1-1                             --
│    └─depthwise_separable_conv: 2-1          --
│    │    └─Conv2d: 3-1                       21
│    │    └─Conv2d: 3-2                       228
│    └─BatchNorm2d: 2-2                       128
│    └─ReLU: 2-3                              --
│    └─depthwise_separable_conv: 2-4          --
│    │    └─Conv2d: 3-3                       491
│    │    └─Conv2d: 3-4                       1,481
│    └─BatchNorm2d: 2-5                       128
│    └─ReLU: 2-6                              --
│    └─MaxPool2d: 2-7                         --
│    └─Dropout: 2-8                           --
│    └─depthwise_separable_conv: 2-9          --
│    │    └─Conv2d: 3-5                       476
│    │    └─Conv2d: 3-6                       2,731
│    └─BatchNorm2d: 2-10                      256
│    └─ReLU: 2-11                             --
│  

### weight_orig是沒有pruned過的參數，weight_mask中0的位置是要prune掉的參數，兩者的product就是pruned過的參數

In [227]:
list(pruned_net.named_parameters())[:2]

[('cnn.0.depthwise.bias',
  Parameter containing:
  tensor([0.1712, 0.1426, 0.0943], device='cuda:0', requires_grad=True)),
 ('cnn.0.depthwise.weight_orig',
  Parameter containing:
  tensor([[[[-0.2440,  0.1404,  0.1166],
            [-0.2298,  0.3388,  0.1480],
            [-0.0352,  0.1292, -0.0036]]],
  
  
          [[[ 0.0281, -0.3052, -0.1770],
            [ 0.0575, -0.0050,  0.1125],
            [ 0.1379,  0.0019, -0.0563]]],
  
  
          [[[-0.2921, -0.0949, -0.0833],
            [-0.2350, -0.1615,  0.0462],
            [ 0.0827,  0.0352, -0.0306]]]], device='cuda:0', requires_grad=True))]

In [229]:
list(pruned_net.named_buffers())[:1] #mask

[('cnn.0.depthwise.weight_mask',
  tensor([[[[1., 1., 1.],
            [1., 1., 1.],
            [0., 1., 0.]]],
  
  
          [[[0., 1., 1.],
            [0., 0., 1.],
            [1., 0., 1.]]],
  
  
          [[[1., 1., 1.],
            [1., 1., 0.],
            [1., 0., 0.]]]], device='cuda:0'))]

In [224]:
pruned_net.cnn[0].depthwise.weight #pruned完的參數

tensor([[[[-0.2440,  0.1404,  0.1166],
          [-0.2298,  0.3388,  0.1480],
          [-0.0000,  0.1292, -0.0000]]],


        [[[ 0.0000, -0.3052, -0.1770],
          [ 0.0000, -0.0000,  0.1125],
          [ 0.1379,  0.0000, -0.0563]]],


        [[[-0.2921, -0.0949, -0.0833],
          [-0.2350, -0.1615,  0.0000],
          [ 0.0827,  0.0000, -0.0000]]]], device='cuda:0')

# After removing "prune mask", the original parameter will be replaced with pruned parameter

In [254]:
for i,p in enumerate(param_to_prune):
    if p[0]._forward_pre_hooks:
        prune.remove(p[0],'weight')

### Also ,the number of parameters will be restored(97790->474216)

In [255]:
import torchinfo as summary
summary.summary(pruned_net)

Layer (type:depth-idx)                        Param #
StudentModel                                  --
├─Sequential: 1-1                             --
│    └─depthwise_separable_conv: 2-1          --
│    │    └─Conv2d: 3-1                       30
│    │    └─Conv2d: 3-2                       256
│    └─BatchNorm2d: 2-2                       128
│    └─ReLU: 2-3                              --
│    └─depthwise_separable_conv: 2-4          --
│    │    └─Conv2d: 3-3                       640
│    │    └─Conv2d: 3-4                       4,160
│    └─BatchNorm2d: 2-5                       128
│    └─ReLU: 2-6                              --
│    └─MaxPool2d: 2-7                         --
│    └─Dropout: 2-8                           --
│    └─depthwise_separable_conv: 2-9          --
│    │    └─Conv2d: 3-5                       640
│    │    └─Conv2d: 3-6                       8,320
│    └─BatchNorm2d: 2-10                      256
│    └─ReLU: 2-11                             --
│  