# Setting

In [1]:
import pandas as pd
import torch
import random
import seaborn as sns
import os
import pickle
import numpy as np

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import KFold

from torch.utils.data import DataLoader, Dataset, Subset
import torch.nn as nn
import torch.optim as optim

from transformers import BertTokenizer

import matplotlib.pyplot as plt
from perceiver import tokenize_data, CustomDataset, PerceiverBlock, Perceiver, CombinedModel

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
def seed_everything(seed):
    torch.manual_seed(seed) #torch를 거치는 모든 난수들의 생성순서를 고정한다
    torch.cuda.manual_seed(seed) #cuda를 사용하는 메소드들의 난수시드는 따로 고정해줘야한다 
    torch.cuda.manual_seed_all(seed)  # if use multi-GPU
    torch.backends.cudnn.deterministic = True #딥러닝에 특화된 CuDNN의 난수시드도 고정 
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed) #numpy를 사용할 경우 고정
    random.seed(seed) #파이썬 자체 모듈 random 모듈의 시드 고정
seed_everything(42)

## Import Data

In [3]:
file_path = '/home/jisoo/n24news/n24news/captions_and_labels.csv'
data = pd.read_csv(file_path)

groups = [
    ['Opinion', 'Food', 'Movies'],
    ['Art & Design', 'Science', 'Fashion & Style'],
    ['Television', 'Sports', 'Style'],
    ['Music', 'Health', 'Dance'],
    ['Real Estate', 'Books', 'Media'],
    ['Travel', 'Theater', 'Technology']
]

output_paths = []
for i, group_labels in enumerate(groups, 1):
    group_data = data[data['Label'].isin(group_labels)]
    output_path = f'/home/jisoo/n24news/n24news/regroup_{i}.csv'
    group_data.to_csv(output_path, index=False)
    output_paths.append(output_path)

# Models 

In [4]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
MAX_LENGTH = 128

In [5]:
root_dir = '/home/Minju/Perceiver/model/'
loader_dir = '/home/Minju/Perceiver/loader/'

batch_size = 32

In [6]:
# class CustomDataset(Dataset):
#     def __init__(self, input_ids, labels):
#         self.input_ids = input_ids
#         self.labels = labels

#     def __len__(self):
#         return len(self.labels)

#     def __getitem__(self, idx):
#         return {
#             'input_ids': self.input_ids[idx],
#             'labels': self.labels[idx]
#         }

## Load Pretrained Model, Dataloader

### Import Model

In [7]:
input_models = []
valid_loaders = []
for i in range (6):
    text_model = torch.load(root_dir + f'text_model_{i+1}.pkl')
    input_models.append(text_model)
    print(f"Text model {i+1}번 불러오기 완료.")

  text_model = torch.load(root_dir + f'text_model_{i+1}.pkl')


Text model 1번 불러오기 완료.
Text model 2번 불러오기 완료.
Text model 3번 불러오기 완료.
Text model 4번 불러오기 완료.
Text model 5번 불러오기 완료.
Text model 6번 불러오기 완료.


In [8]:
for i in range(6):
    img_model = torch.load(root_dir + f'image_model_{i+1}.pkl')
    input_models.append(img_model)
    print(f"Image model {i}번 불러오기 완료.")

Image model 0번 불러오기 완료.


  img_model = torch.load(root_dir + f'image_model_{i+1}.pkl')


Image model 1번 불러오기 완료.
Image model 2번 불러오기 완료.
Image model 3번 불러오기 완료.
Image model 4번 불러오기 완료.
Image model 5번 불러오기 완료.


### Import Dataloader

주의: 현재 text 모달리티는 dataloader 자체가 저장되어있지만 image 모달리티는 데이터가 그대로 저장되어있어 Dataloader로 변환해주어야 합니다. \
일단 지금은 이대로 두지만 언젠가 에러나면 수정이 필요합니다. 

In [9]:
for i in range(6):
    with open(loader_dir+f'text_val_loader_{i+1}.pkl', 'rb') as f:
        loaded_valid_dataset = pickle.load(f)
    valid_loaders.append(loaded_valid_dataset)
    print(f"Text val. loader {i}번 불러오기 완료.")

Text val. loader 0번 불러오기 완료.
Text val. loader 1번 불러오기 완료.
Text val. loader 2번 불러오기 완료.
Text val. loader 3번 불러오기 완료.
Text val. loader 4번 불러오기 완료.
Text val. loader 5번 불러오기 완료.


In [10]:
for i in range(6):
    with open(loader_dir+f'image_val_loader_{i+1}.pkl', 'rb') as f:
        loaded_valid_dataset = pickle.load(f)

    valid_loader = DataLoader(loaded_valid_dataset, batch_size=batch_size, shuffle=False)
    valid_loaders.append(valid_loader)
    print(f"Image val. loader {i}번 불러오기 완료.")

Image val. loader 0번 불러오기 완료.
Image val. loader 1번 불러오기 완료.
Image val. loader 2번 불러오기 완료.
Image val. loader 3번 불러오기 완료.
Image val. loader 4번 불러오기 완료.
Image val. loader 5번 불러오기 완료.


## PackNet Models

In [11]:
class PackNet(nn.Module):
    def __init__(self, model):
        super(PackNet, self).__init__()
        self.model = model
        self.masks = {}
        self.current_task = None

    def set_task(self, task_id):
        self.current_task = task_id
        if task_id not in self.masks:
            self.masks[task_id] = {
                name: torch.ones_like(param, device=param.device)
                for name, param in self.model.named_parameters()
                if param.requires_grad
            }

    def prune(self, sparsity=0.2):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                mask = self.masks[self.current_task][name]
                threshold = torch.quantile(param.abs(), sparsity)
                mask[param.abs() < threshold] = 0
                self.masks[self.current_task][name] = mask

    def forward(self, input_ids, **kwargs):
        if self.current_task in self.masks:
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    if param.requires_grad:
                        param.data *= self.masks[self.current_task][name]
        return self.model(input_ids, **kwargs)

In [12]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss, correct, total = 0, 0, 0

    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    return total_loss / len(dataloader), correct / total

In [13]:
def eval_epoch(model, dataloader, criterion, device):
    model.eval()
    total_loss, correct, total = 0, 0, 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return total_loss / len(dataloader), correct / total

In [14]:
EPOCHS = 20
BATCH_SIZE = 32
K_FOLDS = 5
EMBED_DIM = 128  
LATENT_DIM = 64
LATENT_SIZE = 64
NUM_BLOCKS = 4

In [15]:
def apply_pruning_with_intervals(packnet_model, test_loader, criterion, device, start_sparsity, end_sparsity, pruning_ratio):
    
    current_sparsity = start_sparsity
    while current_sparsity <= end_sparsity:
        print(f"Applying pruning with sparsity: {current_sparsity:.2f}")
        packnet_model.prune(sparsity=current_sparsity)

        # Evaluate after pruning if test_loader is provided
        if test_loader is not None and criterion is not None:
            print("Evaluating after pruning...")
            pruned_test_loss, pruned_test_acc = eval_epoch(packnet_model, test_loader, criterion, device)
            print(f"Pruned Test Loss: {pruned_test_loss:.4f}, Test Accuracy: {pruned_test_acc:.4f}")
        else:
            print(f"Skipping evaluation as 'test_loader' or 'criterion' is None.")

        current_sparsity += pruning_ratio

In [16]:
results = []
all_learning_curves = []

for idx, group_file in enumerate(output_paths, start=1):
    print(f"\nGroup {idx} 처리 중...")

    df = pd.read_csv(group_file)
    label_encoder = LabelEncoder()
    df['Label'] = label_encoder.fit_transform(df['Label'])
    num_classes = len(label_encoder.classes_)

    input_ids, attention_masks = tokenize_data(df)
    labels = torch.tensor(df['Label'].values)

    dataset = CustomDataset(input_ids, attention_masks, labels)
    kfold = KFold(n_splits=K_FOLDS, shuffle=True, random_state=42)

    fold_results = []
    fold_learning_curves = []

    averaged_state_dict = None

    for fold, (train_idx, test_idx) in enumerate(kfold.split(dataset), start=1):
        print(f"\n  Fold {fold}/{K_FOLDS} 처리 중...")

        train_subset = Subset(dataset, train_idx)
        test_subset = Subset(dataset, test_idx)

        train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
        test_loader = DataLoader(test_subset, batch_size=BATCH_SIZE, shuffle=False)

        # Perceiver 모델 초기화
        perceiver = Perceiver(
            input_dim=EMBED_DIM,
            latent_dim=LATENT_DIM,
            latent_size=LATENT_SIZE,
            num_classes=num_classes,
            num_blocks=NUM_BLOCKS,
            self_attn_layers_per_block=1
        )

        # CombinedModel 초기화
        combined_model = CombinedModel(
            vocab_size=tokenizer.vocab_size,
            embed_dim=EMBED_DIM,
            perceiver_model=perceiver
        )

        # PackNet
        packnet_model = PackNet(combined_model)
        packnet_model.to(device)
        packnet_model.set_task(f"task_{idx}_{fold}")

        optimizer = optim.Adam(packnet_model.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
        criterion = nn.CrossEntropyLoss()

        train_losses, test_losses = [], []
        train_accuracies, test_accuracies = [], []

        # Pruning 이전 성능 평가
        print("Pruning 이전 성능:")
        initial_test_loss, initial_test_acc = eval_epoch(packnet_model, test_loader, criterion, device)
        print(f"  Test Loss: {initial_test_loss:.4f}, Test Accuracy: {initial_test_acc:.4f}")

        for epoch in range(EPOCHS):
            train_loss, train_acc = train_epoch(packnet_model, train_loader, criterion, optimizer, device)
            test_loss, test_acc = eval_epoch(packnet_model, test_loader, criterion, device)

            train_losses.append(train_loss)
            test_losses.append(test_loss)
            train_accuracies.append(train_acc)
            test_accuracies.append(test_acc)

            scheduler.step()

            if (epoch + 1) % 5 == 0 or epoch == 0:
                print(f"epoch {epoch+1}/{EPOCHS}: train loss {train_loss:.4f}, train acc {train_acc:.4f}")
                print(f"                         test loss {test_loss:.4f}, test acc {test_acc:.4f}")

        # Fold 모델 상태 저장 및 평균 계산
        state_dict = packnet_model.state_dict()
        if averaged_state_dict is None:
            averaged_state_dict = {key: val.clone() for key, val in state_dict.items()}
        else:
            for key in averaged_state_dict:
                averaged_state_dict[key] += state_dict[key]

        
        # 결과 저장
        fold_results.append({
            "Fold": fold,
            "Test Accuracy": test_acc,
            "Confusion Matrix": None,
            "Classification Report": None
        })
        # learning curve
        fold_learning_curves.append({
            "Fold": fold,
            "train_losses": train_losses,
            "test_losses": test_losses,
            "train_accuracies": train_accuracies,
            "test_accuracies": test_accuracies
        })

        # confusion matrix
        y_true, y_pred = [], []
        packnet_model.eval()
        with torch.no_grad():
            for batch in test_loader:
                input_ids_batch = batch['input_ids'].to(device)
                labels_batch = batch['labels'].to(device)

                outputs = packnet_model(input_ids_batch)
                _, predicted = torch.max(outputs, 1)
                y_true.extend(labels_batch.cpu().numpy())
                y_pred.extend(predicted.cpu().numpy())

        cm = confusion_matrix(y_true, y_pred)
        if cm.ndim != 2:
            raise ValueError(f"Confusion Matrix must be 2D, but got shape {cm.shape}.")
        report = classification_report(y_true, y_pred, output_dict=True)

        fold_results.append({
        "Fold": fold,
        "Test Accuracy": test_acc,
        "Confusion Matrix": cm,
        "Classification Report": classification_report(y_true, y_pred, output_dict=True)
        })
    
    avg_accuracy = np.mean([fr["Test Accuracy"] for fr in fold_results])
    results.append({
        "Group": idx,
        "Average Test Accuracy": avg_accuracy,
        "Fold Results": fold_results
    })

    all_learning_curves.append({
        "Group": idx,
        "Fold Learning Curves": fold_learning_curves
    })

    print(f"\n그룹 {idx}의 {K_FOLDS} 폴드 평균 테스트 정확도: {avg_accuracy:.4f}")

    # 평균 모델 저장
    checkpoint_path = f"/home/jisoo/Perceiver/Perceiver/checkpoints/group_{idx}_average_model.pth.tar"
    torch.save(averaged_state_dict, checkpoint_path)
    print(f"Average model checkpoint for Group {idx} saved at {checkpoint_path}")

    avg_accuracy = np.mean([fr["Test Accuracy"] for fr in fold_results])

    for curve in fold_learning_curves:
        fold_idx = curve["Fold"]

      
        # plt.figure(figsize=(10, 6))
        # plt.plot(range(1, EPOCHS + 1), curve["train_losses"], label="Train Loss")
        # plt.plot(range(1, EPOCHS + 1), curve["test_losses"], label="Test Loss")
        # plt.title(f"Group {idx} - Fold {fold_idx} Learning Curve (Loss)")
        # plt.xlabel("Epoch")
        # plt.ylabel("Loss")
        # plt.legend()
        # plt.grid(True)
        # plt.show()

    
        # plt.figure(figsize=(10, 6))
        # plt.plot(range(1, EPOCHS + 1), curve["train_accuracies"], label="Train Accuracy")
        # plt.plot(range(1, EPOCHS + 1), curve["test_accuracies"], label="Test Accuracy")
        # plt.title(f"Group {idx} - Fold {fold_idx} Learning Curve (Accuracy)")
        # plt.xlabel("Epoch")
        # plt.ylabel("Accuracy")
        # plt.legend()
        # plt.grid(True)
        # plt.show()

    for fold_result in fold_results:
        fold_idx = fold_result["Fold"]
        cm = fold_result["Confusion Matrix"]

        if cm is not None and cm.ndim == 2:
            # plt.figure(figsize=(10, 8))
            # sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            #             xticklabels=label_encoder.classes_,
            #             yticklabels=label_encoder.classes_)
            # plt.title(f"Group {idx} - Fold {fold_idx} Confusion Matrix")
            # plt.xlabel("Predicted")
            # plt.ylabel("Actual")
            # plt.show()
            print("plt 뽑히려던 곳 ")
        else:
            print(f"Confusion Matrix for Fold {fold_idx} is invalid or missing.")


Group 1 처리 중...

  Fold 1/5 처리 중...
Pruning 이전 성능:
  Test Loss: 1.1050, Test Accuracy: 0.3211
epoch 1/20: train loss 1.1071, train acc 0.3487
                         test loss 1.0765, test acc 0.4196
epoch 5/20: train loss 0.5272, train acc 0.7929
                         test loss 0.5932, test acc 0.7642
epoch 10/20: train loss 0.3620, train acc 0.8683
                         test loss 0.4440, test acc 0.8266
epoch 15/20: train loss 0.2931, train acc 0.8957
                         test loss 0.4660, test acc 0.8204
epoch 20/20: train loss 0.2579, train acc 0.9120
                         test loss 0.4483, test acc 0.8377

  Fold 2/5 처리 중...
Pruning 이전 성능:
  Test Loss: 1.1994, Test Accuracy: 0.3366
epoch 1/20: train loss 1.1100, train acc 0.3525
                         test loss 1.0637, test acc 0.3914
epoch 5/20: train loss 0.4674, train acc 0.8231
                         test loss 0.5011, test acc 0.8050
epoch 10/20: train loss 0.3420, train acc 0.8772
                         t

In [17]:
def prune_model(group_idx, checkpoint_dir, start_sparsity, end_sparsity, pruning_ratio, device):
    
    # 기존 모델 로딩
    checkpoint_path = f"{checkpoint_dir}/group_{group_idx}_average_model.pth.tar"
    print(f"Loading checkpoint for Group {group_idx} from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location=device)

    num_classes = checkpoint.get('num_classes', checkpoint['model.perceiver.output_layer.weight'].size(0))

    perceiver = Perceiver(
        input_dim=128, 
        latent_dim=64,
        latent_size=64,
        num_classes=num_classes,
        num_blocks=4,
        self_attn_layers_per_block=1
    )

    combined_model = CombinedModel(
        vocab_size=tokenizer.vocab_size,
        embed_dim=128,
        perceiver_model=perceiver
    )

    packnet_model = PackNet(combined_model)
    model_state_dict = checkpoint

    # 일치하지 않는 키 필터링
    filtered_state_dict = {
        key: value
        for key, value in model_state_dict.items()
        if key in packnet_model.state_dict() and packnet_model.state_dict()[key].size() == value.size()
    }
    packnet_model.load_state_dict(filtered_state_dict, strict=False)
    packnet_model.to(device)
    packnet_model.set_task(f"group_{group_idx}_pruning")
    print("Model successfully loaded with matched parameters.")

    # 가중치 보고 pruning
    current_sparsity = start_sparsity
    while current_sparsity <= end_sparsity:
        print(f"Applying pruning with sparsity: {current_sparsity:.2f}")
        for name, module in packnet_model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                weights = module.weight.data.abs().flatten()
                threshold = torch.quantile(weights, current_sparsity)
                mask = module.weight.data.abs() >= threshold
                module.weight.data *= mask

                # 마스크 선택적으로 저장
                if not hasattr(module, 'mask'):
                    module.mask = mask

        current_sparsity += pruning_ratio

    print("Pruning completed.")

    # pruned model 저장
    pruned_checkpoint_path = f"{checkpoint_dir}/group_{group_idx}_pruned_model.pth.tar"
    torch.save({
        "model_state_dict": packnet_model.state_dict(),
        "num_classes": num_classes,  # Replace with actual number of classes
        "vocab_size": tokenizer.vocab_size
    }, pruned_checkpoint_path)
    print(f"Pruned model for Group {group_idx} saved at {pruned_checkpoint_path}.")

# pruning 파라미터
groups = [1, 2, 3, 4, 5, 6]  
checkpoint_dir = "/home/jisoo/Perceiver/Perceiver/checkpoints"
start_sparsity = 0.05
end_sparsity = 0.2
pruning_ratio = 0.05
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

for group_idx in groups:
    prune_model(
        group_idx=group_idx,
        checkpoint_dir=checkpoint_dir,
        start_sparsity=start_sparsity,
        end_sparsity=end_sparsity,
        pruning_ratio=pruning_ratio,
        device=device
    )

Loading checkpoint for Group 1 from /home/jisoo/Perceiver/Perceiver/checkpoints/group_1_average_model.pth.tar...
Model successfully loaded with matched parameters.
Applying pruning with sparsity: 0.05
Applying pruning with sparsity: 0.10
Applying pruning with sparsity: 0.15
Applying pruning with sparsity: 0.20
Pruning completed.
Pruned model for Group 1 saved at /home/jisoo/Perceiver/Perceiver/checkpoints/group_1_pruned_model.pth.tar.
Loading checkpoint for Group 2 from /home/jisoo/Perceiver/Perceiver/checkpoints/group_2_average_model.pth.tar...
Model successfully loaded with matched parameters.
Applying pruning with sparsity: 0.05
Applying pruning with sparsity: 0.10
Applying pruning with sparsity: 0.15
Applying pruning with sparsity: 0.20
Pruning completed.


  checkpoint = torch.load(checkpoint_path, map_location=device)


Pruned model for Group 2 saved at /home/jisoo/Perceiver/Perceiver/checkpoints/group_2_pruned_model.pth.tar.
Loading checkpoint for Group 3 from /home/jisoo/Perceiver/Perceiver/checkpoints/group_3_average_model.pth.tar...
Model successfully loaded with matched parameters.
Applying pruning with sparsity: 0.05
Applying pruning with sparsity: 0.10
Applying pruning with sparsity: 0.15
Applying pruning with sparsity: 0.20
Pruning completed.
Pruned model for Group 3 saved at /home/jisoo/Perceiver/Perceiver/checkpoints/group_3_pruned_model.pth.tar.
Loading checkpoint for Group 4 from /home/jisoo/Perceiver/Perceiver/checkpoints/group_4_average_model.pth.tar...
Model successfully loaded with matched parameters.
Applying pruning with sparsity: 0.05
Applying pruning with sparsity: 0.10
Applying pruning with sparsity: 0.15
Applying pruning with sparsity: 0.20
Pruning completed.
Pruned model for Group 4 saved at /home/jisoo/Perceiver/Perceiver/checkpoints/group_4_pruned_model.pth.tar.
Loading checkp

In [18]:
def prepare_test_loader(group_idx, batch_size=32):
    
    group_file = output_paths[group_idx - 1]
    df = pd.read_csv(group_file)

   
    input_ids, attention_masks = tokenize_data(df)
    label_encoder = LabelEncoder()
    df['Label'] = label_encoder.fit_transform(df['Label'])
    labels = torch.tensor(df['Label'].values)
    test_dataset = CustomDataset(input_ids, attention_masks, labels)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return test_loader, label_encoder

def load_pruned_model(group_idx, checkpoint_dir, device):
    
    checkpoint_path = f"{checkpoint_dir}/group_{group_idx}_pruned_model.pth.tar"
    print(f"Loading pruned model for Group {group_idx} from {checkpoint_path}...")

    # checkpoint 로드
    checkpoint = torch.load(checkpoint_path, map_location=device)
    num_classes = checkpoint.get('num_classes', 10)  # Use default if not in checkpoint

    # 모델 초기화
    perceiver = Perceiver(
        input_dim=128,  # Match embedding dimension
        latent_dim=64,
        latent_size=64,
        num_classes=num_classes,
        num_blocks=4,
        self_attn_layers_per_block=1
    )

    combined_model = CombinedModel(
        vocab_size=checkpoint.get('vocab_size', tokenizer.vocab_size),
        embed_dim=128,
        perceiver_model=perceiver
    )

    packnet_model = PackNet(combined_model)
    packnet_model.load_state_dict(checkpoint['model_state_dict'])
    packnet_model.to(device)
    packnet_model.set_task(f"group_{group_idx}_evaluation")
    print("Pruned model loaded successfully.")

    return packnet_model

In [19]:
def evaluate_pruned_model(packnet_model, test_loader, criterion, device):
    packnet_model.eval()
    total_loss, correct, total = 0, 0, 0

    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            outputs = packnet_model(input_ids)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    avg_loss = total_loss / len(test_loader)
    accuracy = correct / total
    return avg_loss, accuracy

checkpoint_dir = "/home/jisoo/Perceiver/Perceiver/checkpoints"
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

for group_idx in range(1, len(output_paths) + 1):
    # dataloader 불러오기
    test_loader, label_encoder = prepare_test_loader(group_idx, batch_size=32)

    pruned_model = load_pruned_model(group_idx, checkpoint_dir, device)
    criterion = nn.CrossEntropyLoss()

    # evaluate
    test_loss, test_accuracy = evaluate_pruned_model(pruned_model, test_loader, criterion, device)
    print(f"Group {group_idx}: Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")


Loading pruned model for Group 1 from /home/jisoo/Perceiver/Perceiver/checkpoints/group_1_pruned_model.pth.tar...
Pruned model loaded successfully.


  checkpoint = torch.load(checkpoint_path, map_location=device)


Group 1: Test Loss: 13.5007, Test Accuracy: 0.3382
Loading pruned model for Group 2 from /home/jisoo/Perceiver/Perceiver/checkpoints/group_2_pruned_model.pth.tar...
Pruned model loaded successfully.


  checkpoint = torch.load(checkpoint_path, map_location=device)


Group 2: Test Loss: 5.9509, Test Accuracy: 0.3808
Loading pruned model for Group 3 from /home/jisoo/Perceiver/Perceiver/checkpoints/group_3_pruned_model.pth.tar...
Pruned model loaded successfully.


  checkpoint = torch.load(checkpoint_path, map_location=device)


Group 3: Test Loss: 7.1559, Test Accuracy: 0.3536
Loading pruned model for Group 4 from /home/jisoo/Perceiver/Perceiver/checkpoints/group_4_pruned_model.pth.tar...
Pruned model loaded successfully.


  checkpoint = torch.load(checkpoint_path, map_location=device)


Group 4: Test Loss: 12.8544, Test Accuracy: 0.3455
Loading pruned model for Group 5 from /home/jisoo/Perceiver/Perceiver/checkpoints/group_5_pruned_model.pth.tar...
Pruned model loaded successfully.


  checkpoint = torch.load(checkpoint_path, map_location=device)


Group 5: Test Loss: 30.2527, Test Accuracy: 0.3357
Loading pruned model for Group 6 from /home/jisoo/Perceiver/Perceiver/checkpoints/group_6_pruned_model.pth.tar...
Pruned model loaded successfully.


  checkpoint = torch.load(checkpoint_path, map_location=device)


Group 6: Test Loss: 9.8103, Test Accuracy: 0.3455
