In [4]:
import pandas as pd
import torch
import copy
import numpy as np
import random
import pickle
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from transformers import BertTokenizer
from sklearn.preprocessing import LabelEncoder
from perceiver import crop, patchify, get_patch_coords, tokenize_data, CustomDataset, PerceiverBlock, Perceiver, CombinedModel, ImageDataset

# Device 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()

def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) 
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

seed_everything(42)

# 경로 설정
pruned_model_dir = "/home/youlee/perceiver/perceiver/checkpoints_pruned2/"
batch_size = 32

In [5]:
# 이미지 데이터 전처리 및 로드
image_root_dir = "/home/youlee/n24news/n24news/image"
class_groups = [
    ["Opinion", "Art & Design", "Television"],
    ["Music", "Travel", "Real Estate"],
    ["Books", "Theater", "Health"],
    ["Sports", "Science", "Food"],
    ["Fashion & Style", "Movies", "Technology"],
    ["Dance", "Media", "Style"]
]

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

image_datasets = [ImageDataset(image_root_dir, transform=transform, selected_classes=group) for group in class_groups]
# Train Valid Split 
image_train_loaders, image_valid_loaders = [], []
for dataset in image_datasets:
    train_size = int(len(dataset) * 0.8)
    valid_size = len(dataset) - train_size
    train_set, valid_set = random_split(dataset, [train_size, valid_size])
    image_train_loaders.append(DataLoader(train_set, batch_size=32, shuffle=True))
    image_valid_loaders.append(DataLoader(valid_set, batch_size=32, shuffle=False))

# 텍스트 데이터 전처리 및 로드
text_root_dir = "/home/youlee/n24news/n24news/"
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
label_encoder = LabelEncoder()

def load_text_data(file_path):
    df = pd.read_csv(file_path)
    input_ids, attention_masks = tokenize_data(df, tokenizer=tokenizer, MAX_LENGTH=128)
    
    if df["Label"].dtype == "object":
        df["Label"] = label_encoder.fit_transform(df["Label"])
    labels = torch.tensor(df["Label"].values, dtype=torch.long)
    
    return CustomDataset(input_ids, attention_masks, labels)

text_datasets = [load_text_data(f"{text_root_dir}regroup_{i}.csv") for i in range(1, 7)]

# Train Valid Split 
text_train_loaders, text_valid_loaders = [], []
for dataset in text_datasets:
    train_size = int(len(dataset) * 0.8)
    valid_size = len(dataset) - train_size
    train_set, valid_set = random_split(dataset, [train_size, valid_size])
    text_train_loaders.append(DataLoader(train_set, batch_size=32, shuffle=True))
    text_valid_loaders.append(DataLoader(valid_set, batch_size=32, shuffle=False))

train_loaders = text_train_loaders + image_train_loaders
valid_loaders = text_valid_loaders + image_valid_loaders

In [6]:
# Pruned 모델 로드 함수
def load_pruned_models(pruned_model_dir, model_dir):
    pruned_text_models, pruned_image_models = [], []
    
    # Text 모델 로드
    for i in range(6):
        file_path = f"{pruned_model_dir}/text_model_{i+1}_pruned.pkl"
        model_path = f"{model_dir}/text_model_{i+1}.pkl"
        
        with open(file_path, 'rb') as f:
            model_data = pickle.load(f)
        
        model = torch.load(model_path)
        state_dict = {k.replace("model.", ""): v for k, v in model_data["model_state_dict"].items()}
        model.load_state_dict(state_dict, strict=False)
        pruned_text_models.append(model)

    # Image 모델 로드
    for i in range(6):
        file_path = f"{pruned_model_dir}/image_model_{i+1}_pruned.pkl"
        model_path = f"{model_dir}/image_model_{i+1}.pkl"
        
        with open(file_path, 'rb') as f:
            model_data = pickle.load(f)
        
        model = torch.load(model_path)
        state_dict = {k.replace("model.", ""): v for k, v in model_data["model_state_dict"].items()}
        model.load_state_dict(state_dict, strict=False)
        pruned_image_models.append(model)

    return pruned_text_models, pruned_image_models


In [7]:
# # Validation 데이터 로드 함수
# def load_valid_loaders():
#     valid_loaders = []
#     for i in range(6):
#         with open(f"{loader_dir}text_val_loader_{i+1}.pkl", 'rb') as f:
#             valid_loaders.append(pickle.load(f))
#     for i in range(6):
#         with open(f"{loader_dir}image_val_loader_{i+1}.pkl", 'rb') as f:
#             valid_dataset = pickle.load(f)
#             valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
#             valid_loaders.append(valid_loader)
#     return valid_loaders

In [8]:
# 키 매핑 함수
def map_keys(state_dict, modality):
    key_map = {}
    for key in state_dict.keys():
        if modality == "Text" and "model.perceiver" in key:
            key_map[key] = key.replace("model.perceiver", "model")
        elif modality == "Image" and "model." in key:
            key_map[key] = key.replace("model.", "model.perceiver")
        else:
            key_map[key] = key
    return key_map

# 가중치 매핑 및 로드 함수
def apply_mapped_state_dict(task_model, backbone_model, modality):
    backbone_state_dict = backbone_model.state_dict()
    key_map = map_keys(backbone_state_dict, modality)
    mapped_state_dict = {key_map[k]: v for k, v in backbone_state_dict.items()}
    task_model.load_state_dict(mapped_state_dict, strict=False)
    print(f"Successfully applied backbone weights for {modality} task.")

In [9]:
def knowledge_transfer(task_id, backbone_id, backbone_modality, task_modality, 
                       task_models, backbone_models, train_loaders, criterion, device, epochs=20):
    if backbone_modality == "Text":
        backbone_model = backbone_models[backbone_id]
    elif backbone_modality == "Image":
        backbone_model = backbone_models[backbone_id]

    task_model = task_models[task_id]
    task_model.to(device)

    apply_mapped_state_dict(task_model, backbone_model, backbone_modality)

    optimizer = optim.SGD(task_model.parameters(), lr=0.0001, momentum=0.9)
    task_model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        correct = 0
        total = 0

        for batch in train_loaders[task_id]:
            optimizer.zero_grad()
            if task_modality == "Text":
                inputs = batch["input_ids"].to(device)
                labels = batch["labels"].to(device)
                outputs = task_model(inputs)
            else:
                inputs, labels = batch
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = task_model(inputs)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print(f"Task {task_id} | Backbone {backbone_id} | Epoch {epoch+1}/{epochs} | Loss: {total_loss:.4f} | Accuracy: {correct/total:.4f}")
    return task_model

In [10]:
def eval_epoch(model, dataloader, criterion, device, is_text: bool):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    correct_samples = []
    incorrect_samples = []

    with torch.no_grad():
        for batch in dataloader:
            if is_text:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                outputs = model(input_ids)
            else:
                inputs, labels = batch
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)

            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            for i in range(len(labels)):
                if predicted[i] == labels[i] and len(correct_samples) < 4:
                    correct_samples.append((i, labels[i].item(), predicted[i].item()))
                elif predicted[i] != labels[i] and len(incorrect_samples) < 4:
                    incorrect_samples.append((i, labels[i].item(), predicted[i].item()))
            if len(correct_samples) >= 4 and len(incorrect_samples) >= 4:
                break


    avg_loss = total_loss / len(dataloader)
    accuracy = correct / total
    return avg_loss, accuracy, correct_samples, incorrect_samples

In [11]:
# 최종 평가 함수
def evaluate_transfer(models, valid_loaders, criterion, device):
    correct_samples_dict = {}  # Task별 정답 샘플 저장
    incorrect_samples_dict = {}  # Task별 오답 샘플 저장

    for i, model in enumerate(models):
        is_text = i < 6

        test_loss, test_acc, correct_samples, incorrect_samples = eval_epoch(model, valid_loaders[i], criterion, device, is_text)
        modality = "Text" if is_text else "Image"
        print(f"Task {i} ({modality}) | Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

        # Task ID별로 샘플 저장
        correct_samples_dict[i] = correct_samples
        incorrect_samples_dict[i] = incorrect_samples

    return correct_samples_dict, incorrect_samples_dict

In [12]:
# Cosine 및 Euclidean 결과 로드
cosine_results = pd.read_csv("/home/youlee/perceiver/perceiver/code/best_cosine_results.txt", sep='\t')
euclidean_results = pd.read_csv("/home/youlee/perceiver/perceiver/code/best_euclidean_results.txt", sep='\t')
cosine_task_to_backbone = {row["Task_ID"]: row["Best_Target_ID"] for _, row in cosine_results.iterrows()}
euclidean_task_to_backbone = {row["Task_ID"]: row["Best_Target_ID"] for _, row in euclidean_results.iterrows()}

In [13]:
if __name__ == "__main__":
    print("Loading pruned models and validation loaders...")
    
    pruned_text_models, pruned_image_models = load_pruned_models(pruned_model_dir, "/home/youlee/perceiver/perceiver/model")
    pruned_models = pruned_text_models + pruned_image_models

    cosine_models = [copy.deepcopy(model).to(device) for model in pruned_models]
    euclidean_models = [copy.deepcopy(model).to(device) for model in pruned_models]


    print("Starting Cosine-based Knowledge Transfer...")
    for task_id, backbone_id in cosine_task_to_backbone.items():
        backbone_modality = cosine_results.loc[cosine_results["Task_ID"] == task_id, "Target_Modality"].values[0]
        task_modality = cosine_results.loc[cosine_results["Task_ID"] == task_id, "Task_Modality"].values[0]
        knowledge_transfer(task_id, backbone_id, backbone_modality, task_modality, 
                           cosine_models, pruned_models, train_loaders, criterion, device)

    print("Starting Euclidean-based Knowledge Transfer...")
    for task_id, backbone_id in euclidean_task_to_backbone.items():
        backbone_modality = euclidean_results.loc[euclidean_results["Task_ID"] == task_id, "Target_Modality"].values[0]
        task_modality = euclidean_results.loc[euclidean_results["Task_ID"] == task_id, "Task_Modality"].values[0]
        knowledge_transfer(task_id, backbone_id, backbone_modality, task_modality, 
                           euclidean_models, pruned_models, train_loaders, criterion, device)

    print("Evaluating Cosine-based Transfer results...")
    cosine_correct_samples, cosine_incorrect_samples = evaluate_transfer(cosine_models, valid_loaders, criterion, device)


    print("Evaluating Euclidean-based Transfer results...")
    euclidean_correct_samples, euclidean_incorrect_samples = evaluate_transfer(euclidean_models, valid_loaders, criterion, device)

    # ✅ Task별 정답/오답 샘플 최종 출력
    print("\n===== Cosine-based Correct & Incorrect Predictions =====")
    for task_id in cosine_correct_samples.keys():
        print(f"✅ Task {task_id} Correct Predictions:")
        for idx, true_label, pred_label in cosine_correct_samples[task_id]:
            print(f"  - Sample {idx}: True Label={true_label}, Predicted Label={pred_label}")

        print(f"❌ Task {task_id} Incorrect Predictions:")
        for idx, true_label, pred_label in cosine_incorrect_samples[task_id]:
            print(f"  - Sample {idx}: True Label={true_label}, Predicted Label={pred_label}")
        print("-" * 60)

    print("\n===== Euclidean-based Correct & Incorrect Predictions =====")
    for task_id in euclidean_correct_samples.keys():
        print(f"✅ Task {task_id} Correct Predictions:")
        for idx, true_label, pred_label in euclidean_correct_samples[task_id]:
            print(f"  - Sample {idx}: True Label={true_label}, Predicted Label={pred_label}")

        print(f"❌ Task {task_id} Incorrect Predictions:")
        for idx, true_label, pred_label in euclidean_incorrect_samples[task_id]:
            print(f"  - Sample {idx}: True Label={true_label}, Predicted Label={pred_label}")
        print("-" * 60)

Loading pruned models and validation loaders...


  model = torch.load(model_path)
  model = torch.load(model_path)


Starting Cosine-based Knowledge Transfer...
Successfully applied backbone weights for Image task.
Task 0 | Backbone 9 | Epoch 1/20 | Loss: 36.2878 | Accuracy: 0.9480
Task 0 | Backbone 9 | Epoch 2/20 | Loss: 29.9958 | Accuracy: 0.9587
Task 0 | Backbone 9 | Epoch 3/20 | Loss: 27.8242 | Accuracy: 0.9616
Task 0 | Backbone 9 | Epoch 4/20 | Loss: 26.1543 | Accuracy: 0.9653
Task 0 | Backbone 9 | Epoch 5/20 | Loss: 25.1702 | Accuracy: 0.9664
Task 0 | Backbone 9 | Epoch 6/20 | Loss: 24.1259 | Accuracy: 0.9674
Task 0 | Backbone 9 | Epoch 7/20 | Loss: 23.2531 | Accuracy: 0.9700
Task 0 | Backbone 9 | Epoch 8/20 | Loss: 22.4475 | Accuracy: 0.9710
Task 0 | Backbone 9 | Epoch 9/20 | Loss: 22.0822 | Accuracy: 0.9720
Task 0 | Backbone 9 | Epoch 10/20 | Loss: 21.2705 | Accuracy: 0.9731
Task 0 | Backbone 9 | Epoch 11/20 | Loss: 20.9047 | Accuracy: 0.9738
Task 0 | Backbone 9 | Epoch 12/20 | Loss: 20.6284 | Accuracy: 0.9748
Task 0 | Backbone 9 | Epoch 13/20 | Loss: 20.1434 | Accuracy: 0.9758
Task 0 | Backb

In [10]:
import pickle

file_path = "/home/youlee/perceiver/perceiver/checkpoints_pruned/text_model_1_pruned.pkl"

try:
    with open(file_path, 'rb') as f:
        model_data = pickle.load(f)
    print("Successfully loaded the file. Keys in the file:")
    print(model_data.keys())
except Exception as e:
    print(f"Error while opening the file: {e}")


# model_state_dict 내부 키 확인
state_dict = model_data["model_state_dict"]
print("Keys in model_state_dict:")
print(list(state_dict.keys())[:10])  # 첫 10개의 키

# masks 내부 키 확인
masks = model_data["masks"]
print("Keys in masks:")
print(list(masks.keys())[:10])  # 첫 10개의 키


Successfully loaded the file. Keys in the file:
dict_keys(['model_state_dict', 'masks'])
Keys in model_state_dict:
['model.embedding.weight', 'model.perceiver.latents', 'model.perceiver.input_projection.weight', 'model.perceiver.input_projection.bias', 'model.perceiver.blocks.0.cross_attn.in_proj_weight', 'model.perceiver.blocks.0.cross_attn.in_proj_bias', 'model.perceiver.blocks.0.cross_attn.out_proj.weight', 'model.perceiver.blocks.0.cross_attn.out_proj.bias', 'model.perceiver.blocks.0.cross_ln.weight', 'model.perceiver.blocks.0.cross_ln.bias']
Keys in masks:
['text_task_1']


In [11]:
import pickle

file_path = "/home/youlee/perceiver/perceiver/checkpoints_pruned/image_model_1_pruned.pkl"

try:
    with open(file_path, 'rb') as f:
        model_data = pickle.load(f)
    print("Successfully loaded the file. Keys in the file:")
    print(model_data.keys())
except Exception as e:
    print(f"Error while opening the file: {e}")


# model_state_dict 내부 키 확인
state_dict = model_data["model_state_dict"]
print("Keys in model_state_dict:")
print(list(state_dict.keys())[:10])  # 첫 10개의 키

# masks 내부 키 확인
masks = model_data["masks"]
print("Keys in masks:")
print(list(masks.keys())[:10])  # 첫 10개의 키

Successfully loaded the file. Keys in the file:
dict_keys(['model_state_dict', 'masks'])
Keys in model_state_dict:
['model.latents', 'model.input_projection.weight', 'model.input_projection.bias', 'model.blocks.0.cross_attn.in_proj_weight', 'model.blocks.0.cross_attn.in_proj_bias', 'model.blocks.0.cross_attn.out_proj.weight', 'model.blocks.0.cross_attn.out_proj.bias', 'model.blocks.0.cross_ln.weight', 'model.blocks.0.cross_ln.bias', 'model.blocks.0.self_attn_layers.0.self_attn.in_proj_weight']
Keys in masks:
['image_task_1']
