In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torchvision import datasets, transforms, models
import numpy as np 
import os
import joblib
import torch.nn.utils.prune as prune
import time

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Resize so the shorter edge is 224 (224x224 for CIFAR-10's square images),
# convert to a tensor, then normalise with the ImageNet channel statistics.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Both splits share the same preprocessing; files are downloaded into ./data if missing.
train_dataset, test_dataset = (
    datasets.CIFAR10(root="./data", train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)


In [None]:
# Only the training split is shuffled; evaluation order stays fixed.
BATCH_SIZE = 128
train_loader, test_loader = (
    torch.utils.data.DataLoader(split, batch_size=BATCH_SIZE, shuffle=do_shuffle)
    for split, do_shuffle in ((train_dataset, True), (test_dataset, False))
)

In [None]:
# Reproducibility setup (TODO: move into a utils module).
seed = 42
for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(seed)

# cuDNN: request deterministic kernels and turn off benchmark-mode autotuning.
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

In [None]:
import joblib
import os

# Directory holding the trained checkpoint (TODO: move into args).
# Set the SAVED_PATH environment variable to use a different location;
# the default is the original hard-coded path.
saved_path = os.environ.get('SAVED_PATH', '/data/ephemeral/home/nathan/saved')
model_name = os.path.join(saved_path, 'resnet18.joblib')
# NOTE: joblib.load unpickles the file, which can execute arbitrary code —
# only load checkpoints from a trusted source.
model = joblib.load(model_name)

In [None]:
# Loss and optimizer for `model` (TODO: move into args / a criterion module).
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
)

In [None]:
# TODO: move into the train/ folder and modularise.
def train(model, train_loader, num_epochs=10, optimizer=None, criterion=None):
    """Train ``model`` in place and print the mean training loss of each epoch.

    Args:
        model: network to train; moved to ``device`` and set to train mode.
        train_loader: iterable of ``(images, labels)`` batches.
        num_epochs: number of passes over ``train_loader``.
        optimizer: optimizer over ``model``'s parameters. When omitted, a fresh
            ``SGD(lr=0.01, momentum=0.9, weight_decay=5e-4)`` is built for
            ``model`` itself. This avoids stepping a different model's optimizer,
            which would leave e.g. a pruned network untrained.
        criterion: loss function; defaults to ``nn.CrossEntropyLoss()``.

    Returns:
        List with the mean training loss of each epoch.
    """
    # Batches are sent to `device`, so the model has to live there too.
    model.to(device)
    if optimizer is None:
        optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    epoch_losses = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)

            loss = criterion(model(images), labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        mean_loss = running_loss / max(len(train_loader), 1)
        epoch_losses.append(mean_loss)
        print(f'Epochs: {epoch+1}/{num_epochs} Training loss: {mean_loss}')
    return epoch_losses
        
def test(model, test_loader, howmany):
    """Evaluate ``model`` on ``test_loader``; print accuracy, elapsed time and parameter stats.

    Args:
        model: network to evaluate; set to eval mode.
        test_loader: iterable of ``(images, labels)`` batches.
        howmany: number of pruned channels; only echoed in the printed report.

    Returns:
        Top-1 accuracy in percent (0.0 if the loader yielded no samples).
    """
    model.eval()
    # perf_counter is monotonic and intended for interval timing. The measured
    # span also covers data loading and host-to-device copies, not only forward passes.
    start = time.perf_counter()
    total = 0
    correct = 0
    with torch.no_grad():
        for images, labels in tqdm(test_loader):
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            _, predicted = torch.max(output, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    elapsed = time.perf_counter() - start

    accuracy = 100 * correct / total if total else 0.0
    print(f'Accuracy: {accuracy}%, Forward Time: {elapsed:.2f}s, pruned_channel: {howmany}')
    get_model_memory_usage(model)
    return accuracy

def get_model_memory_usage(model):
    """Print and return the count and byte size of ``model``'s trainable parameters.

    Only parameters with ``requires_grad=True`` are counted. Frozen parameters
    and buffers returned by ``model.buffers()`` are not included.

    Returns:
        ``(total_params, total_bytes)``.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    total_params = sum(p.numel() for p in trainable)
    total_memory = sum(p.numel() * p.element_size() for p in trainable)  # Bytes

    print(f"Total Parameters: {total_params}")
    print(f"Memory Usage for Parameters: {total_memory / 1e6:.2f} MB")  # Convert to MB
    return total_params, total_memory

### Structured pruning for CNN 은 2가지로 구분된다
1. Conv layer filter pruning (연산속도 개선) -> 학습된 모델에 대한 sensitivity analysis 필요
2. FC (fully connected) layer pruning (파라미터 수 감소)

In [None]:
# Sensitivity analysis: for every Conv2d (downsample convs included, unlike the
# pruning cells below), compute each output filter's L1 norm (sum of |weights|).
# Sort the norms in descending order and divide by the largest, so values fall in [0, 1].
sensitivity_layer = {}
for name, module in model.named_modules():
    if not isinstance(module, nn.Conv2d):
        continue
    # One row per output filter, all of that filter's weights flattened.
    filter_weights = module.weight.detach().cpu().numpy()
    filter_weights = filter_weights.reshape(filter_weights.shape[0], -1)
    l1_norms = np.sort(np.abs(filter_weights).sum(axis=1))[::-1]
    # An all-zero layer would otherwise divide by zero and yield NaNs.
    largest = l1_norms[0]
    sensitivity_layer[name] = l1_norms / largest if largest > 0 else l1_norms
        
    

### Conv layer별 filter sensitivity 시각화

In [None]:
import matplotlib.pyplot as plt
# Matplotlib colour and line-style codes, combined into one format string per curve.
colors = list('rgbkymc')
lines = '- -- -.'.split()


In [None]:
# One curve per conv layer: its normalised filter L1 norms (sorted descending)
# against the filter's position as a percentage of the layer's filter count.
fig, ax = plt.subplots(figsize=(7, 5))
for count, sensitivity in enumerate(sensitivity_layer.values(), start=1):
    style_idx = count - 1
    # Colour cycles fastest, then line style. Both wrap around, so models with more
    # than len(colors) * len(lines) conv layers no longer raise IndexError.
    line_style = colors[style_idx % len(colors)] + lines[(style_idx // len(colors)) % len(lines)]
    x = np.linspace(0, 100, num=sensitivity.shape[0])
    ax.plot(x, sensitivity, line_style, label='conv %d' % count)
ax.set_ylabel("normalized abs sum of filter weight")
ax.set_xlabel("filter index / # filters (%)")
ax.legend(loc='upper right')
# The x-axis extends past 100% to leave room for the legend.
ax.set_xlim([0, 140])
ax.grid()
plt.show()

In [None]:
# Show the loaded architecture, so the layer names used by the pruning cells can be checked.
print(f"{model}")

### prune 하기

In [None]:
# Pruning schedule, applied per conv layer:
#   max_ratio  - largest fraction of a layer's output filters to remove.
#   step_ratio - number of points in the np.linspace schedule from 0 up to that
#                maximum, i.e. step_ratio - 1 pruning increments.
max_ratio, step_ratio = 0.9, 8

In [None]:
# Ordered index of the layers the pruning code works with: main-path Conv2d and
# BatchNorm2d (anything whose name contains 'downsample' is skipped) plus every
# Linear. Keys are consecutive ints in named_modules() order.
idx2name_module = {}
i = 0
for name, module in model.named_modules():
    main_path = 'downsample' not in name
    tracked = isinstance(module, nn.Linear) or (
        main_path and isinstance(module, (nn.Conv2d, nn.BatchNorm2d))
    )
    if tracked:
        idx2name_module[i] = (name, module)
        i += 1

In [None]:
from prune_function import *
from prune import *

# For each main-path Conv2d, prune an increasing number of its filters from a
# freshly loaded copy of the model and evaluate each result.
# The conv's index comes from idx2name_module by name. The old `idx += 2`
# bookkeeping silently assumed every conv is followed by exactly one BatchNorm.
name2idx = {layer_name: layer_idx for layer_idx, (layer_name, _) in idx2name_module.items()}
for name, module in model.named_modules():
    if not isinstance(module, nn.Conv2d) or 'downsample' in name:
        continue
    idx = name2idx[name]
    # Evenly spaced cumulative prune counts from 0 to max_ratio * out_channels;
    # steps holds the number of extra filters each stage removes.
    step = np.linspace(0, int(module.out_channels * max_ratio), step_ratio, dtype=int)
    steps = step[1:] - step[:-1]
    # Only the first half of the schedule is evaluated.
    for i in range(len(steps) // 2):
        num_pruned = sum(steps[:i + 1])
        # Reload for every stage, because each stage prunes a different amount
        # starting from the same trained weights.
        network = joblib.load(model_name)
        num_channel = module.out_channels - num_pruned
        print(name, num_pruned)
        network = prune_step(network, name, num_channel, idx2name_module, index=idx).to(device)
        print("-*-"*10 + "\n\tPrune network\n" + "-*-"*10)
        print(network)

        network_name_v = 'resenet' + '_' + name + '_' + str(num_pruned) + '.joblib'
        network_name = os.path.join(saved_path, network_name_v)

        # Uncomment to persist each pruned network.
        #joblib.dump(network, network_name)
        test(network, test_loader, num_pruned)

In [None]:
# Key into idx2name_module of the conv layer swept by the next cell.
# NOTE(review): presumably layer2.0.conv2 if the checkpoint is a torchvision
# ResNet-18 — verify with idx2name_module[12].
idx = 12

In [None]:
from prune_function import *
from prune import *

# Re-run the prune/evaluate sweep for the single conv layer selected by `idx`
# (set in the previous cell). The same `idx` is used for the lookup and for
# prune_step, so the layer name and its index can never disagree.
name, module = idx2name_module[idx]
assert isinstance(module, nn.Conv2d), f"idx2name_module[{idx}] is {name!r}, not a Conv2d"
step = np.linspace(0, int(module.out_channels * max_ratio), step_ratio, dtype=int)
# steps: how many extra filters each successive stage removes.
steps = step[1:] - step[:-1]
# Only the first half of the schedule is evaluated.
for i in range(len(steps) // 2):
    num_pruned = sum(steps[:i + 1])
    # Reload (on CPU) for every stage, because each stage prunes a different amount.
    network = joblib.load(model_name).to('cpu')
    num_channel = module.out_channels - num_pruned
    print(name, num_pruned)
    network = prune_step(network, name, num_channel, idx2name_module, index=idx).to(device)
    print("-*-"*10 + "\n\tPrune network\n" + "-*-"*10)
    print(network)

    network_name_v = 'resenet' + '_' + name + '_' + str(num_pruned) + '.joblib'
    network_name = os.path.join(saved_path, network_name_v)

    # Uncomment to persist each pruned network.
    #joblib.dump(network, network_name)
    test(network, test_loader, num_pruned)


