In [1]:
import pandas as pd
import torch
import pickle
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from perceiver import tokenize_data, CustomDataset, PerceiverBlock, Perceiver, CombinedModel

# Device 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()

# 경로 설정
pruned_model_dir = "/home/youlee/perceiver/perceiver/checkpoints_pruned/"
loader_dir = "/home/youlee/perceiver/perceiver/loader/"
batch_size = 32

In [2]:
# Pruned 모델 로드 함수
def load_pruned_models(pruned_model_dir, model_dir):
    pruned_text_models, pruned_image_models = [], []
    
    # Text 모델 로드
    for i in range(6):
        file_path = f"{pruned_model_dir}/text_model_{i+1}_pruned.pkl"
        model_path = f"{model_dir}/text_model_{i+1}.pkl"
        
        with open(file_path, 'rb') as f:
            model_data = pickle.load(f)
        
        model = torch.load(model_path)
        state_dict = {k.replace("model.", ""): v for k, v in model_data["model_state_dict"].items()}
        model.load_state_dict(state_dict, strict=False)
        pruned_text_models.append(model)

    # Image 모델 로드
    for i in range(6):
        file_path = f"{pruned_model_dir}/image_model_{i+1}_pruned.pkl"
        model_path = f"{model_dir}/image_model_{i+1}.pkl"
        
        with open(file_path, 'rb') as f:
            model_data = pickle.load(f)
        
        model = torch.load(model_path)
        state_dict = {k.replace("model.", ""): v for k, v in model_data["model_state_dict"].items()}
        model.load_state_dict(state_dict, strict=False)
        pruned_image_models.append(model)

    return pruned_text_models, pruned_image_models

In [3]:
# Validation 데이터 로드 함수
def load_valid_loaders():
    valid_loaders = []
    for i in range(6):
        with open(f"{loader_dir}text_val_loader_{i+1}.pkl", 'rb') as f:
            valid_loaders.append(pickle.load(f))
    for i in range(6):
        with open(f"{loader_dir}image_val_loader_{i+1}.pkl", 'rb') as f:
            valid_dataset = pickle.load(f)
            valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
            valid_loaders.append(valid_loader)
    return valid_loaders

In [4]:
# 키 매핑 함수
def map_keys(state_dict, modality):
    key_map = {}
    for key in state_dict.keys():
        if modality == "Text" and "model.perceiver" in key:
            key_map[key] = key.replace("model.perceiver", "model")
        elif modality == "Image" and "model." in key:
            key_map[key] = key.replace("model.", "model.perceiver")
        else:
            key_map[key] = key
    return key_map

# 가중치 매핑 및 로드 함수
def apply_mapped_state_dict(task_model, backbone_model, modality):
    backbone_state_dict = backbone_model.state_dict()
    key_map = map_keys(backbone_state_dict, modality)
    mapped_state_dict = {key_map[k]: v for k, v in backbone_state_dict.items()}
    task_model.load_state_dict(mapped_state_dict, strict=False)
    print(f"Successfully applied backbone weights for {modality} task.")

In [5]:
def knowledge_transfer(task_id, backbone_id, backbone_modality, task_modality, 
                       task_models, backbone_models, valid_loaders, criterion, device, epochs=20):
    if backbone_modality == "Text":
        backbone_model = backbone_models[backbone_id]
    else:
        backbone_model = backbone_models[backbone_id - 6]

    task_model = task_models[task_id]
    task_model.to(device)

    apply_mapped_state_dict(task_model, backbone_model, backbone_modality)

    optimizer = optim.SGD(task_model.parameters(), lr=0.001, momentum=0.9)
    task_model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        correct = 0
        total = 0

        for batch in valid_loaders[task_id]:
            optimizer.zero_grad()
            if task_modality == "Text":
                inputs = batch["input_ids"].to(device)
                labels = batch["labels"].to(device)
                outputs = task_model(inputs)
            else:
                inputs, labels = batch
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = task_model(inputs)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print(f"Task {task_id} | Backbone {backbone_id} | Epoch {epoch+1}/{epochs} | Loss: {total_loss:.4f} | Accuracy: {correct/total:.4f}")
    return task_model

In [6]:
def eval_epoch(model, dataloader, criterion, device, is_text: bool):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in dataloader:
            if is_text:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                outputs = model(input_ids)
            else:
                inputs, labels = batch
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)

            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    avg_loss = total_loss / len(dataloader)
    accuracy = correct / total
    return avg_loss, accuracy

In [7]:
# 최종 평가 함수
def evaluate_transfer(models, valid_loaders, criterion, device):
    for i, model in enumerate(models):
        is_text = i < 6
        test_loss, test_acc = eval_epoch(model, valid_loaders[i], criterion, device, is_text)
        modality = "Text" if is_text else "Image"
        print(f"Task {i} ({modality}) | Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

In [8]:
# Cosine 및 Euclidean 결과 로드
cosine_results = pd.read_csv("/home/youlee/perceiver/perceiver/code/best_cosine_results.txt", sep='\t')
euclidean_results = pd.read_csv("/home/youlee/perceiver/perceiver/code/best_euclidean_results.txt", sep='\t')
cosine_task_to_backbone = {row["Task_ID"]: row["Best_Target_ID"] for _, row in cosine_results.iterrows()}
euclidean_task_to_backbone = {row["Task_ID"]: row["Best_Target_ID"] for _, row in euclidean_results.iterrows()}

In [9]:
if __name__ == "__main__":
    print("Loading pruned models and validation loaders...")
    
    pruned_text_models, pruned_image_models = load_pruned_models(pruned_model_dir, "/home/youlee/perceiver/perceiver/model")
    valid_loaders = load_valid_loaders()
    pruned_models = pruned_text_models + pruned_image_models

    print("Starting Cosine-based Knowledge Transfer...")
    for task_id, backbone_id in cosine_task_to_backbone.items():
        backbone_modality = cosine_results.loc[cosine_results["Task_ID"] == task_id, "Target_Modality"].values[0]
        task_modality = cosine_results.loc[cosine_results["Task_ID"] == task_id, "Task_Modality"].values[0]
        knowledge_transfer(task_id, backbone_id, backbone_modality, task_modality, 
                           pruned_models, pruned_models, valid_loaders, criterion, device)

    print("Starting Euclidean-based Knowledge Transfer...")
    for task_id, backbone_id in euclidean_task_to_backbone.items():
        backbone_modality = euclidean_results.loc[euclidean_results["Task_ID"] == task_id, "Target_Modality"].values[0]
        task_modality = euclidean_results.loc[euclidean_results["Task_ID"] == task_id, "Task_Modality"].values[0]
        knowledge_transfer(task_id, backbone_id, backbone_modality, task_modality, 
                           pruned_models, pruned_models, valid_loaders, criterion, device)

    print("Evaluating final results...")
    evaluate_transfer(pruned_models, valid_loaders, criterion, device)

Loading pruned models and validation loaders...


  model = torch.load(model_path)
  model = torch.load(model_path)


Starting Cosine-based Knowledge Transfer...
Successfully applied backbone weights for Image task.
Task 0 | Backbone 9 | Epoch 1/20 | Loss: 103.9075 | Accuracy: 0.4417
Task 0 | Backbone 9 | Epoch 2/20 | Loss: 44.6477 | Accuracy: 0.5336
Task 0 | Backbone 9 | Epoch 3/20 | Loss: 42.2105 | Accuracy: 0.5727
Task 0 | Backbone 9 | Epoch 4/20 | Loss: 40.4263 | Accuracy: 0.5988
Task 0 | Backbone 9 | Epoch 5/20 | Loss: 38.8032 | Accuracy: 0.6159
Task 0 | Backbone 9 | Epoch 6/20 | Loss: 37.1759 | Accuracy: 0.6440
Task 0 | Backbone 9 | Epoch 7/20 | Loss: 35.4552 | Accuracy: 0.6708
Task 0 | Backbone 9 | Epoch 8/20 | Loss: 33.6654 | Accuracy: 0.7003
Task 0 | Backbone 9 | Epoch 9/20 | Loss: 32.0442 | Accuracy: 0.7188
Task 0 | Backbone 9 | Epoch 10/20 | Loss: 30.7255 | Accuracy: 0.7366
Task 0 | Backbone 9 | Epoch 11/20 | Loss: 29.6251 | Accuracy: 0.7483
Task 0 | Backbone 9 | Epoch 12/20 | Loss: 28.4551 | Accuracy: 0.7558
Task 0 | Backbone 9 | Epoch 13/20 | Loss: 27.3229 | Accuracy: 0.7709
Task 0 | Back

In [10]:
import pickle

file_path = "/home/youlee/perceiver/perceiver/checkpoints_pruned/text_model_1_pruned.pkl"

try:
    with open(file_path, 'rb') as f:
        model_data = pickle.load(f)
    print("Successfully loaded the file. Keys in the file:")
    print(model_data.keys())
except Exception as e:
    print(f"Error while opening the file: {e}")


# model_state_dict 내부 키 확인
state_dict = model_data["model_state_dict"]
print("Keys in model_state_dict:")
print(list(state_dict.keys())[:10])  # 첫 10개의 키

# masks 내부 키 확인
masks = model_data["masks"]
print("Keys in masks:")
print(list(masks.keys())[:10])  # 첫 10개의 키


Successfully loaded the file. Keys in the file:
dict_keys(['model_state_dict', 'masks'])
Keys in model_state_dict:
['model.embedding.weight', 'model.perceiver.latents', 'model.perceiver.input_projection.weight', 'model.perceiver.input_projection.bias', 'model.perceiver.blocks.0.cross_attn.in_proj_weight', 'model.perceiver.blocks.0.cross_attn.in_proj_bias', 'model.perceiver.blocks.0.cross_attn.out_proj.weight', 'model.perceiver.blocks.0.cross_attn.out_proj.bias', 'model.perceiver.blocks.0.cross_ln.weight', 'model.perceiver.blocks.0.cross_ln.bias']
Keys in masks:
['text_task_1']


In [11]:
import pickle

file_path = "/home/youlee/perceiver/perceiver/checkpoints_pruned/image_model_1_pruned.pkl"

try:
    with open(file_path, 'rb') as f:
        model_data = pickle.load(f)
    print("Successfully loaded the file. Keys in the file:")
    print(model_data.keys())
except Exception as e:
    print(f"Error while opening the file: {e}")


# model_state_dict 내부 키 확인
state_dict = model_data["model_state_dict"]
print("Keys in model_state_dict:")
print(list(state_dict.keys())[:10])  # 첫 10개의 키

# masks 내부 키 확인
masks = model_data["masks"]
print("Keys in masks:")
print(list(masks.keys())[:10])  # 첫 10개의 키

Successfully loaded the file. Keys in the file:
dict_keys(['model_state_dict', 'masks'])
Keys in model_state_dict:
['model.latents', 'model.input_projection.weight', 'model.input_projection.bias', 'model.blocks.0.cross_attn.in_proj_weight', 'model.blocks.0.cross_attn.in_proj_bias', 'model.blocks.0.cross_attn.out_proj.weight', 'model.blocks.0.cross_attn.out_proj.bias', 'model.blocks.0.cross_ln.weight', 'model.blocks.0.cross_ln.bias', 'model.blocks.0.self_attn_layers.0.self_attn.in_proj_weight']
Keys in masks:
['image_task_1']
