# Setting

In [1]:
import pandas as pd
import torch
import random
import seaborn as sns
import os
import pickle
import numpy as np

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import KFold

from torch.utils.data import DataLoader, Dataset, Subset
import torch.nn as nn
import torch.optim as optim

from transformers import BertTokenizer

import matplotlib.pyplot as plt
from perceiver import tokenize_data, CustomDataset, PerceiverBlock, Perceiver, CombinedModel

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
def seed_everything(seed):
    torch.manual_seed(seed) #torch를 거치는 모든 난수들의 생성순서를 고정한다
    torch.cuda.manual_seed(seed) #cuda를 사용하는 메소드들의 난수시드는 따로 고정해줘야한다 
    torch.cuda.manual_seed_all(seed)  # if use multi-GPU
    torch.backends.cudnn.deterministic = True #딥러닝에 특화된 CuDNN의 난수시드도 고정 
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed) #numpy를 사용할 경우 고정
    random.seed(seed) #파이썬 자체 모듈 random 모듈의 시드 고정
seed_everything(42)

## Import Data

# Models 

In [3]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
MAX_LENGTH = 128

In [4]:
root_dir = '/home/jisoo/Perceiver/model/'
loader_dir = '/home/jisoo/Perceiver/loader/'

batch_size = 32

In [5]:
# class CustomDataset(Dataset):
#     def __init__(self, input_ids, labels):
#         self.input_ids = input_ids
#         self.labels = labels

#     def __len__(self):
#         return len(self.labels)

#     def __getitem__(self, idx):
#         return {
#             'input_ids': self.input_ids[idx],
#             'labels': self.labels[idx]
#         }

## Load Pretrained Model, Dataloader

### Import Model

In [6]:
input_models = []
valid_loaders = []
for i in range (6):
    text_model = torch.load(root_dir + f'text_model_{i+1}.pkl')
    input_models.append(text_model)
    print(f"Text model {i+1}번 불러오기 완료.")

  text_model = torch.load(root_dir + f'text_model_{i+1}.pkl')


Text model 1번 불러오기 완료.
Text model 2번 불러오기 완료.
Text model 3번 불러오기 완료.
Text model 4번 불러오기 완료.
Text model 5번 불러오기 완료.
Text model 6번 불러오기 완료.


In [7]:
for i in range(6):
    img_model = torch.load(root_dir + f'image_model_{i+1}.pkl')
    input_models.append(img_model)
    print(f"Image model {i}번 불러오기 완료.")

  img_model = torch.load(root_dir + f'image_model_{i+1}.pkl')


Image model 0번 불러오기 완료.
Image model 1번 불러오기 완료.
Image model 2번 불러오기 완료.
Image model 3번 불러오기 완료.
Image model 4번 불러오기 완료.
Image model 5번 불러오기 완료.


### Import Dataloader

주의: 현재 text 모달리티는 dataloader 자체가 저장되어있지만 image 모달리티는 데이터가 그대로 저장되어있어 Dataloader로 변환해주어야 합니다. \
일단 지금은 이대로 두지만 언젠가 에러나면 수정이 필요합니다. 

In [8]:
for i in range(6):
    with open(loader_dir+f'text_val_loader_{i+1}.pkl', 'rb') as f:
        loaded_valid_dataset = pickle.load(f)
    valid_loaders.append(loaded_valid_dataset)
    print(f"Text val. loader {i}번 불러오기 완료.")

Text val. loader 0번 불러오기 완료.
Text val. loader 1번 불러오기 완료.
Text val. loader 2번 불러오기 완료.
Text val. loader 3번 불러오기 완료.
Text val. loader 4번 불러오기 완료.
Text val. loader 5번 불러오기 완료.


In [9]:
for i in range(6):
    with open(loader_dir+f'image_val_loader_{i+1}.pkl', 'rb') as f:
        loaded_valid_dataset = pickle.load(f)

    valid_loader = DataLoader(loaded_valid_dataset, batch_size=batch_size, shuffle=False)
    valid_loaders.append(valid_loader)
    print(f"Image val. loader {i}번 불러오기 완료.")

Image val. loader 0번 불러오기 완료.
Image val. loader 1번 불러오기 완료.
Image val. loader 2번 불러오기 완료.
Image val. loader 3번 불러오기 완료.
Image val. loader 4번 불러오기 완료.
Image val. loader 5번 불러오기 완료.


## PackNet Models

In [10]:
class PackNet(nn.Module):
    def __init__(self, model):
        super(PackNet, self).__init__()
        self.model = model
        self.masks = {}
        self.current_task = None

    def set_task(self, task_id):
        self.current_task = task_id
        if task_id not in self.masks:
            self.masks[task_id] = {
                name: torch.ones_like(param, device=param.device)
                for name, param in self.model.named_parameters()
                if param.requires_grad
            }

    def prune(self, target_sparsity=0.2):
        if self.current_task is None:
            raise ValueError("Task must be set before pruning.")
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                mask = self.masks[self.current_task][name]
                threshold = torch.quantile(param.abs(), target_sparsity)
                mask[param.abs() < threshold] = 0
                self.masks[self.current_task][name] = mask

    def forward(self, input_ids, **kwargs):
        if self.current_task in self.masks:
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    if param.requires_grad:
                        param.data *= self.masks[self.current_task][name]
        return self.model(input_ids, **kwargs)

In [11]:
def eval_epoch(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)  
            labels = batch['labels'].to(device)

            outputs = model(input_ids)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    avg_loss = total_loss / len(dataloader)
    accuracy = correct / total
    return avg_loss, accuracy

In [12]:
def gradual_pruning(packnet_model, model_type, model_index, criterion, device, start_sparsity, end_sparsity, pruning_steps, checkpoint_dir, valid_loaders):
    model_path = f"{checkpoint_dir}/{model_type}_model_{model_index+1}_pruned.pkl"
    
    sparsity_increment = (end_sparsity - start_sparsity) / pruning_steps
    current_sparsity = start_sparsity
    
    test_loader = valid_loaders[model_index] if model_type == "text" else valid_loaders[model_index + 6]

    for step in range(pruning_steps):
        print(f"[{model_type.upper()} Model {model_index+1}] Pruning Step {step+1}/{pruning_steps} with sparsity {current_sparsity:.2f}")
        packnet_model.prune(target_sparsity=current_sparsity)
        
        current_sparsity += sparsity_increment
    
    with open(model_path, "wb") as f:
        pickle.dump({
            "model_state_dict": packnet_model.state_dict(),
            "masks": packnet_model.masks
        }, f)
    print(f"[{model_type.upper()} Model {model_index+1}] Pruned model saved at {model_path}.")
    
    test_loss, test_acc = eval_epoch(packnet_model, test_loader, criterion, device)
    print(f"[{model_type.upper()} Model {model_index+1}] Final Test Accuracy: {test_acc:.4f}")
    print("---------")
    
    return packnet_model

In [13]:
if __name__ == "__main__":
    start_sparsity = 0.05
    end_sparsity = 0.2
    pruning_steps = 5
    checkpoint_dir = "/home/jisoo/Perceiver/checkpoints_pruned"
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    criterion = nn.CrossEntropyLoss()

    print("Starting gradual pruning process...")
    text_models = input_models[:6]
    image_models = input_models[6:]

    pruned_text_models = []
    for i, model in enumerate(text_models):
        packnet_model = PackNet(model)
        packnet_model.set_task(f"text_task_{i+1}")
        pruned_text_models.append(gradual_pruning(packnet_model, "text", i, criterion, device, start_sparsity, end_sparsity, pruning_steps, checkpoint_dir, valid_loaders))
    
    pruned_image_models = []
    for i, model in enumerate(image_models):
        packnet_model = PackNet(model)
        packnet_model.set_task(f"image_task_{i+1}")
        #pruned_image_models.append(gradual_pruning(packnet_model, "image", i, criterion, device, start_sparsity, end_sparsity, pruning_steps, checkpoint_dir, valid_loaders))
    
    pruned_models = pruned_text_models + pruned_image_models
    print("Gradual pruning process finished for both text and image models.")


Starting gradual pruning process...
[TEXT Model 1] Pruning Step 1/5 with sparsity 0.05
[TEXT Model 1] Pruning Step 2/5 with sparsity 0.08
[TEXT Model 1] Pruning Step 3/5 with sparsity 0.11
[TEXT Model 1] Pruning Step 4/5 with sparsity 0.14
[TEXT Model 1] Pruning Step 5/5 with sparsity 0.17
[TEXT Model 1] Pruned model saved at /home/jisoo/Perceiver/checkpoints_pruned/text_model_1_pruned.pkl.
[TEXT Model 1] Final Test Accuracy: 0.8661
---------
[TEXT Model 2] Pruning Step 1/5 with sparsity 0.05
[TEXT Model 2] Pruning Step 2/5 with sparsity 0.08
[TEXT Model 2] Pruning Step 3/5 with sparsity 0.11
[TEXT Model 2] Pruning Step 4/5 with sparsity 0.14
[TEXT Model 2] Pruning Step 5/5 with sparsity 0.17
[TEXT Model 2] Pruned model saved at /home/jisoo/Perceiver/checkpoints_pruned/text_model_2_pruned.pkl.
[TEXT Model 2] Final Test Accuracy: 0.8057
---------
[TEXT Model 3] Pruning Step 1/5 with sparsity 0.05
[TEXT Model 3] Pruning Step 2/5 with sparsity 0.08
[TEXT Model 3] Pruning Step 3/5 with spa