# Setting

In [27]:
import logging
import os
import pdb
import math
import glob
import random
import time
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from PIL import Image
from sklearn.preprocessing import LabelEncoder

from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split

import torchvision.transforms as T

import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.nn.parameter import Parameter

from tqdm import tqdm
import sys

import copy
from scipy import spatial
import csv

In [28]:
def patchify(img_tensor, patch_size=16):
    """
    img_tensor: (C, H, W) 형태 (예: (3, 224, 224))
    patch_size: 패치 크기 (16, 16)
    return: (num_patches, patch_dim)
            예) (196, 768)  # (H/16)*(W/16)=14*14=196, 768=16*16*3
    """
    C, H, W = img_tensor.shape
    assert H % patch_size == 0 and W % patch_size == 0, "이미지 크기는 patch_size로 나누어 떨어져야 함"

    # unfold로 (patch_size, patch_size)씩 잘라내기
    patches = img_tensor.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
 
    patches = patches.permute(1, 3, 0, 2, 4).contiguous()

    num_patches_h = H // patch_size
    num_patches_w = W // patch_size
    num_patches = num_patches_h * num_patches_w

    patches = patches.view(num_patches, -1) 
    return patches

def get_patch_coords(num_patches_h, num_patches_w):

    y_coord = torch.linspace(0, 1, steps=num_patches_h)
    x_coord = torch.linspace(0, 1, steps=num_patches_w)
    grid_y, grid_x = torch.meshgrid(y_coord, x_coord, indexing='ij')  # (14,14) each

    coords = torch.stack([grid_x, grid_y], dim=-1)  # (14,14,2)
    coords = coords.view(-1, 2)                     # (196, 2)
    return coords

In [29]:
class ImageDataset(Dataset):
    def __init__(self, root_dir, transform=None, crop_size=32, patch_size=16):
        self.root_dir = root_dir
        self.transform = transform
        self.data = []
        self.label_encoder = LabelEncoder()
        self.crop_size = crop_size
        self.patch_size = patch_size

        # 데이터 및 레이블 추출
        labels = []
        for folder in os.listdir(root_dir):
            folder_path = os.path.join(root_dir, folder)
            if os.path.isdir(folder_path):
                image_files = glob.glob(os.path.join(folder_path, "*.jpg"))
                if len(image_files) == 1:
                    image_path = image_files[0]
                else:
                    raise ValueError(f"폴더 {folder}에 JPG 파일이 하나가 아닙니다.")

                label_path = os.path.join(folder_path, "label.txt")
                if os.path.exists(label_path):
                    with open(label_path, "r") as f:
                        label = f.read().strip()
                        labels.append(label)
                        self.data.append((image_path, label))
                else:
                    raise FileNotFoundError(f"폴더 {folder}에 label.txt가 없습니다.")

        self.label_encoder.fit(labels)
        self.data = [(image_path, self.label_encoder.transform([label])[0]) 
                     for image_path, label in self.data]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image_path, label = self.data[idx]
        image = Image.open(image_path).convert("RGB")
        
        if self.transform:
            image = self.transform(image)  # (3, 224, 224)
        image = crop(image, crop_size=self.crop_size)
        patches = patchify(image, patch_size=self.patch_size)  # (num_patches, patch_dim)

        H, W = image.shape[1], image.shape[2]  
        num_patches_h = H // self.patch_size   
        num_patches_w = W // self.patch_size   
        coords = get_patch_coords(num_patches_h, num_patches_w)  

        combined = torch.cat([patches, coords], dim=1)

        label = torch.tensor(label, dtype=torch.long)
        return combined, label

In [30]:
class PerceiverBlock(nn.Module):
    """
    - Cross Attention (latents -> x)
    - 이어서 Self Attention (latent들 끼리)
    - 보통은 LayerNorm, MLP(FeedForward) 등을 곁들여 residual branch를 구성
    """
    def __init__(self, latent_dim, n_heads=8, self_attn_layers=1):
        super().__init__()
        # Cross Attention
        self.cross_attn = nn.MultiheadAttention(embed_dim=latent_dim, num_heads=n_heads)
        self.cross_ln = nn.LayerNorm(latent_dim)  # 잊지 말고 layernorm

        # Self Attention 여러 층
        self.self_attn_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=latent_dim, nhead=n_heads)
            for _ in range(self_attn_layers)
        ])

    def forward(self, latents, x):
        # latents, x: (T, B, dim) 형태로 가정 (주의!)
        # Perceiver 원리상 latents는 query, x는 key/value

        # 1) Cross Attention
        updated_latents, _ = self.cross_attn(latents, x, x)
        latents = latents + updated_latents        # Residual
        latents = self.cross_ln(latents)           # LayerNorm

        # 2) Self Attention 반복
        for layer in self.self_attn_layers:
            latents = layer(latents)  # 내부적으로 residual/LayerNorm 포함

        return latents


In [31]:
class Perceiver(nn.Module):
    def __init__(self, input_dim, latent_dim, latent_size, num_classes,
                 num_blocks, self_attn_layers_per_block=1):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(latent_size, latent_dim))
        self.input_projection = nn.Linear(input_dim, latent_dim)

        # 반복될 PerceiverBlock을 여러 개 쌓는다.
        self.blocks = nn.ModuleList([
            PerceiverBlock(
                latent_dim=latent_dim,
                n_heads=8,
                self_attn_layers=self_attn_layers_per_block
            )
            for _ in range(num_blocks)
        ])

        self.output_layer = nn.Linear(latent_dim, num_classes)

    def forward(self, x):
        """
        x: (B, T, F) = (배치, 시퀀스길이, 피처차원)
        """
        B, T, F = x.size()
        x = self.input_projection(x)                 # (B, T, latent_dim)

        # latents: (latent_size, latent_dim) -> 배치 차원 확장 (B, latent_size, latent_dim)
        latents = self.latents.unsqueeze(0).expand(B, -1, -1)

        # MultiHeadAttention은 (seq, batch, dim) 순서를 권장하므로 permute
        x = x.permute(1, 0, 2)        # (T, B, latent_dim)
        latents = latents.permute(1, 0, 2)  # (latent_size, B, latent_dim)

        # 여러 개의 PerceiverBlock 반복
        for block in self.blocks:
            latents = block(latents, x)

        # 최종 latents: (latent_size, B, latent_dim)
        latents = latents.permute(1, 0, 2).mean(dim=1)  # (B, latent_dim)
        return self.output_layer(latents)


# Load Data

# Load Model

In [32]:
root_dir = '/home/Minju/Perceiver/loader/'
loader_dir = '/home/Minju/Perceiver/loader/'

batch_size = 32

In [33]:
text_models = []
image_models = []
valid_loaders = []

for i in range (6):
    #text_model = torch.load(root_dir + f'model_text_{i}')
    #text_models.append(text_model)

    img_model = torch.load(root_dir + f'model_image_{i}.pkl')
    image_models.append(img_model)

    with open(loader_dir+f'val_loader_{i}.pkl', 'rb') as f:
        loaded_valid_dataset = pickle.load(f)
    valid_loader = DataLoader(loaded_valid_dataset, batch_size=batch_size, shuffle=False)
    valid_loaders.append(valid_loader)

  img_model = torch.load(root_dir + f'model_image_{i}.pkl')


In [34]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

## Define ModelDiff

In [35]:
num_classes = -1
lr = 0.1
batch_size = 32
val_batch_size = 100
workers = 24
weight_decay = 4e-5
dataset_name = ''
train_path = ''
val_path = ''
cuda = True
seed = 1
epochs = 160
restore_epoch = 0
save_folder = ''
load_folder = ''
one_shot_prune_perc = 0.5
mode = ''
logfile = ''
initial_from_task = ''

In [36]:
DATASETS = [
    'Opinion','Art & Design','Television',
    'Music','Travel','Real Estate',
    'Books','Theater','Health',
    'Sports','Science','Food',
    'Fashion & Style','Movies','Technology',
    'Dance', 'Media', 'Style'
]

In [37]:
epsilon = 0.1
max_iterations = 100

In [38]:
target_id = 0

# 유사도검색

## Approach 1: 특정 input data로 유사도 검증

In [39]:
def compute_ddv_cos(model1, model2, inputs):
    global outputs
    global outputs2
        
    with torch.no_grad():
        dists = []
        outputs = model1(torch.Tensor(inputs).cuda()).to('cpu').tolist()
        n_pairs = int(len(list(inputs)) / 2)
        for i in range(n_pairs):
            ya = outputs[i]
            yb = outputs[i + n_pairs]
            dist = spatial.distance.cosine(ya, yb)
            dists.append(dist)

        dists2 = []
        outputs2 = model2(torch.Tensor(inputs).cuda()).to('cpu').tolist()
        n_pairs2 = int(len(list(inputs)) / 2)
        for i in range(n_pairs2):
            ya = outputs2[i]
            yb = outputs2[i + n_pairs]
            dist = spatial.distance.cosine(ya, yb)
            dists2.append(dist)
    return np.array(dists), np.array(dists2)

In [40]:
def compute_ddv_euc(model1, model2, inputs):
    global outputs
    global outputs2
        
    with torch.no_grad():
        dists = []
        outputs = model1(torch.Tensor(inputs).cuda()).to('cpu').tolist()
        n_pairs = int(len(list(inputs)) / 2)
        for i in range(n_pairs):
            ya = outputs[i]
            yb = outputs[i + n_pairs]
            dist = spatial.distance.euclidean(ya, yb) # dist = spatial.distance.cosine(ya, yb)
            dists.append(dist)

        dists2 = []
        outputs2 = model2(torch.Tensor(inputs).cuda()).to('cpu').tolist()
        n_pairs2 = int(len(list(inputs)) / 2)
        for i in range(n_pairs2):
            ya = outputs2[i]
            yb = outputs2[i + n_pairs]
            dist = spatial.distance.euclidean(ya, yb) # dist = spatial.distance.cosine(ya, yb)
            dists2.append(dist)
    return np.array(dists), np.array(dists2)

In [41]:
##### compute_similarity #####
def compute_sim_cos(ddv1, ddv2):
    return spatial.distance.cosine(ddv1, ddv2)    

In [43]:
ddvcc_list = []
ddvec_list = []

for task_id in range(6):
    if task_id == target_id:
        continue
    ddv1, ddv2 = compute_ddv_cos(image_models[target_id], image_models[task_id], valid_loaders[target_id])
    ddv_distance = compute_sim_cos(ddv1, ddv2)
    print('DDV cos-cos [%d => %d] %.5f'%(task_id, target_id, ddv_distance))
    ddvcc_list.append(ddv_distance)

    # DDV-EC
    ddv1, ddv2 = compute_ddv_euc(image_models[target_id], image_models[task_id], valid_loaders[target_id])
    ddv_distance = compute_sim_cos(ddv1, ddv2)
    print('DDV euc-cos [%d => %d] %.5f'%(task_id, target_id, ddv_distance))
    ddvec_list.append(ddv_distance)

TypeError: new(): data must be a sequence (got DataLoader)

## Approach 2: latent vector로 유사도 검증