In [1]:
import matplotlib.pyplot as plt # for visualizing
import os
import torch
from torch import nn
import torchvision
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor
import math
import torch.nn as nn
from collections import OrderedDict

import torch.backends.cudnn as cudnn
import argparse


import numpy as np

# Record the library versions this notebook was executed with.
print(torch.__version__)
print(torchvision.__version__)

2.8.0+cu126
0.23.0+cu126


In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
# Mount Google Drive at /content/drive so the project files under it are reachable.
# NOTE: `google.colab` is only importable inside a Google Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
# Project root on the mounted Drive; the dataset directory sits directly below it.
origin_dir = '/content/drive/MyDrive/Colab_Notebooks/URP/VGG16_CIFAR100_pruning'
data_dir = f'{origin_dir}/data'

In [6]:

# Per-channel normalisation statistics shared by the train and test pipelines
# (presumably the CIFAR-100 training-set mean / std -- TODO confirm).
CIFAR100_MEAN = (0.5071, 0.4867, 0.4408)
CIFAR100_STD = (0.2675, 0.2565, 0.2761)

# Training pipeline: augment the PIL image first, then convert and normalise.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),   # augmentation: random crop
    transforms.RandomHorizontalFlip(),      # augmentation: left-right flip
    # AutoAugment, using the policy torchvision provides under the CIFAR10 name
    transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
    transforms.ToTensor(),                  # convert to a tensor
    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])

# Evaluation pipeline: no augmentation, only conversion and normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])

# Both splits are stored under data_dir and downloaded there when missing.
trainset = torchvision.datasets.CIFAR100(
    root=data_dir,
    train=True,
    download=True,
    transform=transform_train,
)

testset = torchvision.datasets.CIFAR100(
    root=data_dir,
    train=False,
    download=True,
    transform=transform_test,
)


In [7]:
# Both loaders share batch size and worker count; only the training loader shuffles.
loader_kwargs = dict(batch_size=128, num_workers=2)

trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)

# Show both loader objects as the cell output.
(trainloader, testloader)

(<torch.utils.data.dataloader.DataLoader at 0x7f06b59efce0>,
 <torch.utils.data.dataloader.DataLoader at 0x7f06b5a1f980>)

Model

In [8]:
import math
import torch.nn as nn
from collections import OrderedDict

# Not referenced by any code visible in this notebook section.
norm_mean, norm_var = 0.0, 1.0

# Default layer plan. An int is the output width of one conv block
# (3x3 conv -> batch norm -> ReLU); 'M' is a 2x2 max-pool. The LAST entry is
# not a conv layer: it is the width of the hidden linear layer in the classifier.
defaultcfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 512]
# Positions inside ``VGG().features`` (default cfg, batch norm on) of the last
# module of each conv block's output stage: the block's ReLU, or the max-pool
# that directly follows it when there is one.
relucfg = [2, 6, 9, 13, 16, 19, 23, 26, 29, 33, 36, 39, 42]
# Positions inside ``VGG().features`` of the first twelve Conv2d modules
# (the thirteenth conv, at position 40, is not listed).
convcfg = [0, 3, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37]


class VGG(nn.Module):
    """VGG-style network: a stack of conv blocks followed by a two-layer classifier.

    Args:
        num_classes: Number of output logits.
        init_weights: When true, run ``_initialize_weights`` after construction.
        cfg: Layer plan in the format of ``defaultcfg``; ``None`` selects
            ``defaultcfg``. ``cfg[:-1]`` builds ``features``; ``cfg[-2]`` and
            ``cfg[-1]`` size the first linear layer of the classifier.

    Attributes:
        relucfg / covcfg: The module-level ``relucfg`` / ``convcfg`` index lists.
            They are stored as-is and only line up with ``features`` when the
            default layout (default cfg, batch norm enabled) is used.

    The classifier consumes ``cfg[-2]`` features, so the feature map must be 1x1
    after the final average pool; with the default cfg a 32x32 input gives that.
    """

    def __init__(self, num_classes=100, init_weights=True, cfg=None):
        super().__init__()

        if cfg is None:
            cfg = defaultcfg

        self.relucfg = relucfg
        self.covcfg = convcfg
        self.features = self.make_layers(cfg[:-1], True)
        self.classifier = nn.Sequential(OrderedDict([
            ('linear1', nn.Linear(cfg[-2], cfg[-1])),
            ('norm1', nn.BatchNorm1d(cfg[-1])),
            ('relu1', nn.ReLU(inplace=True)),
            ('linear2', nn.Linear(cfg[-1], num_classes)),
        ]))

        if init_weights:
            self._initialize_weights()

    def make_layers(self, cfg, batch_norm=True):
        """Build the feature extractor described by ``cfg``.

        Args:
            cfg: Sequence of ints (conv output widths) and ``'M'`` (max-pool).
            batch_norm: Insert a ``BatchNorm2d`` after every conv. Disabling it
                shifts module positions, so ``relucfg`` / ``convcfg`` no longer
                apply.

        Returns:
            ``nn.Sequential`` whose children are named ``conv<i>``, ``norm<i>``,
            ``relu<i>`` and ``pool<i>``, where ``i`` is the entry's index in ``cfg``.
        """
        layers = nn.Sequential()
        in_channels = 3  # the first conv always takes a 3-channel input
        for i, v in enumerate(cfg):
            if v == 'M':
                layers.add_module('pool%d' % i, nn.MaxPool2d(kernel_size=2, stride=2))
                continue

            layers.add_module('conv%d' % i, nn.Conv2d(in_channels, v, kernel_size=3, padding=1))
            if batch_norm:
                layers.add_module('norm%d' % i, nn.BatchNorm2d(v))
            layers.add_module('relu%d' % i, nn.ReLU(inplace=True))
            in_channels = v
        return layers

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        x = self.features(x)
        # Functional pooling: same result as nn.AvgPool2d(2) without building a
        # throw-away module on every forward pass.
        x = nn.functional.avg_pool2d(x, 2)
        x = x.flatten(1)
        return self.classifier(x)

    def _initialize_weights(self):
        """Initialise conv, 2-D batch-norm and linear parameters in place.

        ``BatchNorm1d`` (used in the classifier) matches none of the branches
        below and therefore keeps PyTorch's default initialisation.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with variance 2 / (kernel area * output channels).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(0.5)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

In [9]:
# Build the default VGG and place it on the selected device in one step.
net = VGG().to(device)

In [10]:
print('==> Building model..')
# Weights file that is loaded into `net` below.
pre_train_model = f'{origin_dir}/original_model/model_epoch_120.pth'

print('==> Resuming from checkpoint..')
# The loaded object is handed straight to load_state_dict, so the file has to
# contain a bare state dict; tensors are mapped onto `device` while loading.
checkpoint = torch.load(pre_train_model, map_location=device)
net.load_state_dict(checkpoint)

criterion = nn.CrossEntropyLoss()

# Global accumulators read and rebound by get_feature_hook:
# the running mean rank per channel and the number of samples folded in so far.
feature_result, total = torch.tensor(0.), torch.tensor(0.)

==> Building model..
==> Resuming from checkpoint..


In [11]:
# Inspect the layer sequence of the feature extractor.
net.features

Sequential(
  (conv0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu0): ReLU(inplace=True)
  (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU(inplace=True)
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (norm3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu3): ReLU(inplace=True)
  (conv4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (norm4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu4): ReLU(inplace=True)
  (pool5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv6): Conv2d(12

In [12]:
def get_feature_hook(self, input, output):
    """Forward hook that folds one batch into the running mean rank per channel.

    For every sample and channel the rank of the 2-D feature map is computed;
    the ranks are summed over the batch and merged into the module-level
    running mean ``feature_result`` using the sample counter ``total``.
    Both globals are rebound on every call.

    Args:
        self: The hooked module (unused).
        input: The module's inputs (unused).
        output: The module's output; assumed to be laid out as
            (batch, channels, height, width).
    """
    global feature_result
    global total

    batch_size = output.shape[0]

    # matrix_rank treats all leading dimensions as batch dimensions and applies
    # its tolerance per matrix, so a single call ranks every (sample, channel)
    # map -- no Python loop and no per-element host synchronisation.
    ranks = torch.linalg.matrix_rank(output)

    # Sum over the batch -> one value per channel. Moved to the CPU because the
    # accumulators are created as CPU tensors.
    c = ranks.float().sum(0).cpu()

    feature_result = feature_result * total + c   # back to a running sum, add this batch
    total = total + batch_size
    feature_result = feature_result / total       # running mean over all samples seen

In [13]:
def test(loader=None, limit=20):
    """Run ``net`` in eval mode over the first ``limit`` batches and print metrics.

    The forward passes are what trigger any forward hook registered on ``net``;
    the loss and accuracy printed per batch are only a sanity check.

    Args:
        loader: Iterable of ``(inputs, targets)`` batches. Defaults to the
            global ``trainloader`` -- despite the function's name, the training
            loader is what is iterated by default.
        limit: Maximum number of batches to process.
    """
    if loader is None:
        loader = trainloader

    net.eval()
    test_loss = 0
    correct = 0
    # Named `seen` rather than `total` so it is not confused with the global
    # `total` accumulator that the forward hook maintains.
    seen = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(loader):
            if batch_idx >= limit:  # only the first `limit` batches are used
                break
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            seen += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            print('batch_idx: ' , batch_idx, limit, '\nLoss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/seen, correct, seen))


In [14]:
from timeit import default_timer as time

save_dir = origin_dir + '/rank_generation/'
# np.save does not create missing directories, so make sure the target exists.
os.makedirs(save_dir, exist_ok=True)

for i, cov_id in enumerate(relucfg):
    # Start every layer from clean accumulators, so re-running this cell after
    # an interrupted run cannot mix in stale statistics.
    feature_result = torch.tensor(0.)
    total = torch.tensor(0.)

    cov_layer = net.features[cov_id]
    # get_feature_hook runs on every forward pass through this layer.
    handler = cov_layer.register_forward_hook(get_feature_hook)
    try:
        test()              # drives the forward passes; printed loss/accuracy is a sanity check
    finally:
        # Always detach the hook, even if test() raises or is interrupted;
        # a leftover hook would double-count on the next run.
        handler.remove()

    np.save(save_dir+'/rank_conv' + str(i + 1) + '.npy', feature_result.numpy())
    print('saved the',str(i + 1),'layer')

# Leave the accumulators reset once all layers are done.
feature_result = torch.tensor(0.)
total = torch.tensor(0.)

batch_idx:  0 20 
Loss: 0.280 | Acc: 94.531% (121/128)
batch_idx:  1 20 
Loss: 0.377 | Acc: 92.188% (236/256)
batch_idx:  2 20 
Loss: 0.437 | Acc: 91.146% (350/384)
batch_idx:  3 20 
Loss: 0.446 | Acc: 91.211% (467/512)
batch_idx:  4 20 
Loss: 0.455 | Acc: 91.250% (584/640)
batch_idx:  5 20 
Loss: 0.439 | Acc: 91.927% (706/768)
batch_idx:  6 20 
Loss: 0.445 | Acc: 91.853% (823/896)
batch_idx:  7 20 
Loss: 0.430 | Acc: 91.895% (941/1024)
batch_idx:  8 20 
Loss: 0.451 | Acc: 91.319% (1052/1152)
batch_idx:  9 20 
Loss: 0.444 | Acc: 91.484% (1171/1280)
batch_idx:  10 20 
Loss: 0.434 | Acc: 91.761% (1292/1408)
batch_idx:  11 20 
Loss: 0.430 | Acc: 91.927% (1412/1536)
batch_idx:  12 20 
Loss: 0.424 | Acc: 92.127% (1533/1664)
batch_idx:  13 20 
Loss: 0.432 | Acc: 91.853% (1646/1792)
batch_idx:  14 20 
Loss: 0.429 | Acc: 91.875% (1764/1920)
batch_idx:  15 20 
Loss: 0.436 | Acc: 91.650% (1877/2048)
batch_idx:  16 20 
Loss: 0.451 | Acc: 91.131% (1983/2176)
batch_idx:  17 20 
Loss: 0.447 | Acc: 9

각 relu 포인트에 register_forward_hook(get_feature_hook)을 등록
test에서 forward가 실행될 때마다 get_feature_hook 실행
limit 횟수만큼 forward 수행
rank 계산 결과를 누적하고 running 평균(누적 평균)을 업데이트