In [5]:
import pandas as pd
import torch
import pickle
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import BertTokenizer
from sklearn.preprocessing import LabelEncoder
from perceiver import crop, patchify, get_patch_coords, tokenize_data, CustomDataset, PerceiverBlock, Perceiver, CombinedModel, ImageDataset

# Device 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()

# 경로 설정
pruned_model_dir = "/home/youlee/perceiver/perceiver/checkpoints_pruned2/"
batch_size = 32

In [6]:
# 이미지 데이터 전처리 및 로드
image_root_dir = "/home/youlee/n24news/n24news/image"
class_groups = [
    ["Opinion", "Art & Design", "Television"],
    ["Music", "Travel", "Real Estate"],
    ["Books", "Theater", "Health"],
    ["Sports", "Science", "Food"],
    ["Fashion & Style", "Movies", "Technology"],
    ["Dance", "Media", "Style"]
]

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

image_datasets = [ImageDataset(image_root_dir, transform=transform, selected_classes=group) for group in class_groups]
image_loaders = [DataLoader(dataset, batch_size=32, shuffle=True) for dataset in image_datasets]

# 텍스트 데이터 전처리 및 로드
text_root_dir = "/home/youlee/n24news/n24news/"
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

label_encoder = LabelEncoder()

def load_text_data(file_path):
    df = pd.read_csv(file_path)
    
    input_ids, attention_masks = tokenize_data(df, tokenizer=tokenizer, MAX_LENGTH=128)
    
    if df["Label"].dtype == "object":
        df["Label"] = label_encoder.fit_transform(df["Label"])

    labels = torch.tensor(df["Label"].values, dtype=torch.long)
    return CustomDataset(input_ids, attention_masks, labels)

text_datasets = [load_text_data(f"{text_root_dir}regroup_{i}.csv") for i in range(1, 7)]
text_loaders = [DataLoader(dataset, batch_size=32, shuffle=True) for dataset in text_datasets]


In [7]:
# Pruned 모델 로드 함수
def load_pruned_models(pruned_model_dir, model_dir):
    pruned_text_models, pruned_image_models = [], []
    
    # Text 모델 로드
    for i in range(6):
        file_path = f"{pruned_model_dir}/text_model_{i+1}_pruned.pkl"
        model_path = f"{model_dir}/text_model_{i+1}.pkl"
        
        with open(file_path, 'rb') as f:
            model_data = pickle.load(f)
        
        model = torch.load(model_path)
        state_dict = {k.replace("model.", ""): v for k, v in model_data["model_state_dict"].items()}
        model.load_state_dict(state_dict, strict=False)
        pruned_text_models.append(model)

    # Image 모델 로드
    for i in range(6):
        file_path = f"{pruned_model_dir}/image_model_{i+1}_pruned.pkl"
        model_path = f"{model_dir}/image_model_{i+1}.pkl"
        
        with open(file_path, 'rb') as f:
            model_data = pickle.load(f)
        
        model = torch.load(model_path)
        state_dict = {k.replace("model.", ""): v for k, v in model_data["model_state_dict"].items()}
        model.load_state_dict(state_dict, strict=False)
        pruned_image_models.append(model)

    return pruned_text_models, pruned_image_models


In [8]:
# # Validation 데이터 로드 함수
# def load_valid_loaders():
#     valid_loaders = []
#     for i in range(6):
#         with open(f"{loader_dir}text_val_loader_{i+1}.pkl", 'rb') as f:
#             valid_loaders.append(pickle.load(f))
#     for i in range(6):
#         with open(f"{loader_dir}image_val_loader_{i+1}.pkl", 'rb') as f:
#             valid_dataset = pickle.load(f)
#             valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
#             valid_loaders.append(valid_loader)
#     return valid_loaders

In [9]:
# 키 매핑 함수
def map_keys(state_dict, modality):
    key_map = {}
    for key in state_dict.keys():
        if modality == "Text" and "model.perceiver" in key:
            key_map[key] = key.replace("model.perceiver", "model")
        elif modality == "Image" and "model." in key:
            key_map[key] = key.replace("model.", "model.perceiver")
        else:
            key_map[key] = key
    return key_map

# 가중치 매핑 및 로드 함수
def apply_mapped_state_dict(task_model, backbone_model, modality):
    backbone_state_dict = backbone_model.state_dict()
    key_map = map_keys(backbone_state_dict, modality)
    mapped_state_dict = {key_map[k]: v for k, v in backbone_state_dict.items()}
    task_model.load_state_dict(mapped_state_dict, strict=False)
    print(f"Successfully applied backbone weights for {modality} task.")

In [10]:
def knowledge_transfer(task_id, backbone_id, backbone_modality, task_modality, 
                       task_models, backbone_models, valid_loaders, criterion, device, epochs=20):
    if backbone_modality == "Text":
        backbone_model = backbone_models[backbone_id]
    else:
        backbone_model = backbone_models[backbone_id - 6]

    task_model = task_models[task_id]
    task_model.to(device)

    apply_mapped_state_dict(task_model, backbone_model, backbone_modality)

    optimizer = optim.SGD(task_model.parameters(), lr=0.001, momentum=0.9)
    task_model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        correct = 0
        total = 0

        for batch in valid_loaders[task_id]:
            optimizer.zero_grad()
            if task_modality == "Text":
                inputs = batch["input_ids"].to(device)
                labels = batch["labels"].to(device)
                outputs = task_model(inputs)
            else:
                inputs, labels = batch
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = task_model(inputs)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print(f"Task {task_id} | Backbone {backbone_id} | Epoch {epoch+1}/{epochs} | Loss: {total_loss:.4f} | Accuracy: {correct/total:.4f}")
    return task_model

In [11]:
def eval_epoch(model, dataloader, criterion, device, is_text: bool):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in dataloader:
            if is_text:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                outputs = model(input_ids)
            else:
                inputs, labels = batch
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)

            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    avg_loss = total_loss / len(dataloader)
    accuracy = correct / total
    return avg_loss, accuracy

In [12]:
# 최종 평가 함수
def evaluate_transfer(models, valid_loaders, criterion, device):
    for i, model in enumerate(models):
        is_text = i < 6
        test_loss, test_acc = eval_epoch(model, valid_loaders[i], criterion, device, is_text)
        modality = "Text" if is_text else "Image"
        print(f"Task {i} ({modality}) | Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

In [13]:
# Cosine 및 Euclidean 결과 로드
cosine_results = pd.read_csv("/home/youlee/perceiver/perceiver/code/best_cosine_results.txt", sep='\t')
euclidean_results = pd.read_csv("/home/youlee/perceiver/perceiver/code/best_euclidean_results.txt", sep='\t')
cosine_task_to_backbone = {row["Task_ID"]: row["Best_Target_ID"] for _, row in cosine_results.iterrows()}
euclidean_task_to_backbone = {row["Task_ID"]: row["Best_Target_ID"] for _, row in euclidean_results.iterrows()}

In [None]:
if __name__ == "__main__":
    print("Loading pruned models and validation loaders...")
    
    pruned_text_models, pruned_image_models = load_pruned_models(pruned_model_dir, "/home/youlee/perceiver/perceiver/model")
    all_loaders = text_loaders + image_loaders
    pruned_models = pruned_text_models + pruned_image_models

    print("Starting Cosine-based Knowledge Transfer...")
    for task_id, backbone_id in cosine_task_to_backbone.items():
        backbone_modality = cosine_results.loc[cosine_results["Task_ID"] == task_id, "Target_Modality"].values[0]
        task_modality = cosine_results.loc[cosine_results["Task_ID"] == task_id, "Task_Modality"].values[0]
        knowledge_transfer(task_id, backbone_id, backbone_modality, task_modality, 
                           pruned_models, pruned_models, all_loaders, criterion, device)

    print("Starting Euclidean-based Knowledge Transfer...")
    for task_id, backbone_id in euclidean_task_to_backbone.items():
        backbone_modality = euclidean_results.loc[euclidean_results["Task_ID"] == task_id, "Target_Modality"].values[0]
        task_modality = euclidean_results.loc[euclidean_results["Task_ID"] == task_id, "Task_Modality"].values[0]
        knowledge_transfer(task_id, backbone_id, backbone_modality, task_modality, 
                           pruned_models, pruned_models, all_loaders, criterion, device)

    print("Evaluating final results...")
    evaluate_transfer(pruned_models, all_loaders, criterion, device)

Loading pruned models and validation loaders...


  model = torch.load(model_path)
  model = torch.load(model_path)


Starting Cosine-based Knowledge Transfer...
Successfully applied backbone weights for Image task.
Task 0 | Backbone 9 | Epoch 1/20 | Loss: 244.0055 | Accuracy: 0.5415
Task 0 | Backbone 9 | Epoch 2/20 | Loss: 180.0655 | Accuracy: 0.6369
Task 0 | Backbone 9 | Epoch 3/20 | Loss: 158.1939 | Accuracy: 0.7104
Task 0 | Backbone 9 | Epoch 4/20 | Loss: 140.0918 | Accuracy: 0.7601
Task 0 | Backbone 9 | Epoch 5/20 | Loss: 128.5700 | Accuracy: 0.7861
Task 0 | Backbone 9 | Epoch 6/20 | Loss: 120.4406 | Accuracy: 0.7990
Task 0 | Backbone 9 | Epoch 7/20 | Loss: 114.3906 | Accuracy: 0.8165
Task 0 | Backbone 9 | Epoch 8/20 | Loss: 109.5414 | Accuracy: 0.8209
Task 0 | Backbone 9 | Epoch 9/20 | Loss: 105.1199 | Accuracy: 0.8289
Task 0 | Backbone 9 | Epoch 10/20 | Loss: 101.7848 | Accuracy: 0.8330
Task 0 | Backbone 9 | Epoch 11/20 | Loss: 98.9302 | Accuracy: 0.8379
Task 0 | Backbone 9 | Epoch 12/20 | Loss: 96.1539 | Accuracy: 0.8438
Task 0 | Backbone 9 | Epoch 13/20 | Loss: 93.7768 | Accuracy: 0.8484
Task

In [10]:
import pickle

file_path = "/home/youlee/perceiver/perceiver/checkpoints_pruned/text_model_1_pruned.pkl"

try:
    with open(file_path, 'rb') as f:
        model_data = pickle.load(f)
    print("Successfully loaded the file. Keys in the file:")
    print(model_data.keys())
except Exception as e:
    print(f"Error while opening the file: {e}")


# model_state_dict 내부 키 확인
state_dict = model_data["model_state_dict"]
print("Keys in model_state_dict:")
print(list(state_dict.keys())[:10])  # 첫 10개의 키

# masks 내부 키 확인
masks = model_data["masks"]
print("Keys in masks:")
print(list(masks.keys())[:10])  # 첫 10개의 키


Successfully loaded the file. Keys in the file:
dict_keys(['model_state_dict', 'masks'])
Keys in model_state_dict:
['model.embedding.weight', 'model.perceiver.latents', 'model.perceiver.input_projection.weight', 'model.perceiver.input_projection.bias', 'model.perceiver.blocks.0.cross_attn.in_proj_weight', 'model.perceiver.blocks.0.cross_attn.in_proj_bias', 'model.perceiver.blocks.0.cross_attn.out_proj.weight', 'model.perceiver.blocks.0.cross_attn.out_proj.bias', 'model.perceiver.blocks.0.cross_ln.weight', 'model.perceiver.blocks.0.cross_ln.bias']
Keys in masks:
['text_task_1']


In [11]:
import pickle

file_path = "/home/youlee/perceiver/perceiver/checkpoints_pruned/image_model_1_pruned.pkl"

try:
    with open(file_path, 'rb') as f:
        model_data = pickle.load(f)
    print("Successfully loaded the file. Keys in the file:")
    print(model_data.keys())
except Exception as e:
    print(f"Error while opening the file: {e}")


# model_state_dict 내부 키 확인
state_dict = model_data["model_state_dict"]
print("Keys in model_state_dict:")
print(list(state_dict.keys())[:10])  # 첫 10개의 키

# masks 내부 키 확인
masks = model_data["masks"]
print("Keys in masks:")
print(list(masks.keys())[:10])  # 첫 10개의 키

Successfully loaded the file. Keys in the file:
dict_keys(['model_state_dict', 'masks'])
Keys in model_state_dict:
['model.latents', 'model.input_projection.weight', 'model.input_projection.bias', 'model.blocks.0.cross_attn.in_proj_weight', 'model.blocks.0.cross_attn.in_proj_bias', 'model.blocks.0.cross_attn.out_proj.weight', 'model.blocks.0.cross_attn.out_proj.bias', 'model.blocks.0.cross_ln.weight', 'model.blocks.0.cross_ln.bias', 'model.blocks.0.self_attn_layers.0.self_attn.in_proj_weight']
Keys in masks:
['image_task_1']
