# In [None]:
import os

# Expose only GPU 0 to this process; set before torch is imported further down.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# A notebook has no real __file__, so the *string* "__file__" is resolved
# against the current working directory: its directory is the CWD itself.
# Net effect: move one level up from where the notebook was started, so the
# relative paths used below ("./data/...", "./dip/...") resolve from there.
current_file_path = os.path.abspath("__file__")
current_folder = os.path.split(current_file_path)[0]
parent_folder = os.path.split(current_folder)[0]
os.chdir(parent_folder)



class Logger(object):
    """Tee-style text stream: echo every write to a terminal and append it to a file.

    Meant to be installed as ``sys.stdout``. Whatever ``sys.stdout`` is at
    construction time becomes the terminal side of the tee.

    Args:
        name: path of the log file; opened in append mode so earlier runs are kept.
    """

    def __init__(self, name):
        self.terminal = sys.stdout
        # Explicit UTF-8 so the log's encoding does not depend on the locale.
        self.log = open(name, "a", encoding="utf-8")

    def write(self, message):
        """Write ``message`` to both the terminal and the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both sinks (previously only the file was flushed)."""
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Flush and close the log file; the terminal stream is left open.

        Safe to call more than once.
        """
        if not self.log.closed:
            self.log.flush()
            self.log.close()

        
import torch
from torch import nn

import sys
from torchvision import models, transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights
from PIL import Image
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from vit_pytorch_face import ViT_face, ViT_face_low, ViT_face_up
from vit_pytorch_face import ViTs_face
from vit_pytorch_face import ModifiedViT
import os
from torchvision.datasets import ImageFolder

from util.utils import (
    calculate_prototypes,
    AverageMeter,
    train_accuracy,
    get_unique_classes,
    replace_ffn_with_lora,
    modify_head,
)
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr



def get_cls_token_output(model, x, block_idx=11):
    """Return the feed-forward feature map of transformer block ``block_idx``.

    ``x`` is cast to float32 and passed to ``model`` with
    ``label=1, prepare_feature=True, training=True``. The model must return a
    3-tuple whose middle element is a dict of intermediate features keyed
    ``"layer_<idx>_ff"``.

    Despite the name, the whole feature tensor is returned (every index of the
    first three axes is kept), not only the CLS token. The slice also requires
    the tensor to have at least 3 dims. Presumably the layout is
    (batch, tokens, dim) — callers take ``[:, 0, :]`` for the CLS token.
    """
    _, feature_maps, _ = model(x.float(), label=1, prepare_feature=True, training=True)
    layer_key = f"layer_{block_idx}_ff"
    return feature_maps[layer_key][:, :, :]


def denormalize(tensor, mean, std):
    """Undo per-channel normalisation (``tensor * std + mean``) and clamp to [0, 1].

    ``mean`` and ``std`` are reshaped to ``(1, C, 1, 1)``, so ``tensor`` is
    expected to be an NCHW batch with C == len(mean) == len(std).

    The result stays on ``tensor``'s device. The previous version forced
    everything onto CUDA, which broke on CPU-only machines.

    Args:
        tensor: normalised image batch.
        mean: per-channel means used for the normalisation.
        std: per-channel standard deviations used for the normalisation.

    Returns:
        The de-normalised batch, clamped to [0, 1].
    """
    mean = torch.as_tensor(mean, device=tensor.device).view(1, -1, 1, 1)
    std = torch.as_tensor(std, device=tensor.device).view(1, -1, 1, 1)
    return torch.clamp(tensor * std + mean, 0, 1)


def normalize(tensor, mean, std):
    """Apply per-channel normalisation ``(tensor - mean) / std``.

    ``mean`` and ``std`` are reshaped to ``(1, C, 1, 1)``, so ``tensor`` is
    expected to be an NCHW batch with C == len(mean) == len(std).

    The statistics are created on ``tensor``'s device instead of always on
    CUDA, which previously caused a device mismatch for CPU inputs.

    Args:
        tensor: image batch, typically with values in [0, 1].
        mean: per-channel means.
        std: per-channel standard deviations.

    Returns:
        The normalised batch (not clamped).
    """
    mean = torch.as_tensor(mean, device=tensor.device).view(1, -1, 1, 1)
    std = torch.as_tensor(std, device=tensor.device).view(1, -1, 1, 1)
    return (tensor - mean) / std


def loss_function(generated_img, target_features, model, block_idx):
    """Feature-matching loss for the inversion.

    ``generated_img`` is normalised with the ImageNet statistics, passed
    through ``model``, and the features of block ``block_idx`` are compared to
    ``target_features`` with mean-squared error.

    Returns:
        ``(mse_loss, features)``. The features are returned (still attached to
        the graph) so the caller can log similarity metrics without a second
        forward pass.
    """
    model_input = normalize(generated_img, IMAGENET_MEAN, IMAGENET_STD)
    features = get_cls_token_output(model, model_input, block_idx)
    mse = F.mse_loss(features, target_features)
    return mse, features



class Autoencoder(nn.Module):
    """Convolutional encoder/decoder used as the image generator of the inversion.

    The encoder has four 4x4, stride-2 convolutions with ReLU, widening the
    channels 3 -> 64 -> 128 -> 256 -> 512 and halving the spatial size each
    time. The decoder mirrors it with transposed convolutions. Each
    intermediate decoder stage uses BatchNorm + ReLU, and a final Sigmoid keeps
    outputs in [0, 1]. Spatial sides divisible by 16 round-trip exactly
    (e.g. 224 -> 14 -> 224).
    """

    def __init__(self):
        super().__init__()
        widths = [3, 64, 128, 256, 512]

        down_layers = []
        for in_ch, out_ch in zip(widths, widths[1:]):
            down_layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1))
            down_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*down_layers)

        mirrored = widths[::-1]
        last_stage = len(widths) - 2
        up_layers = []
        for stage, (in_ch, out_ch) in enumerate(zip(mirrored, mirrored[1:])):
            up_layers.append(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
            )
            if stage == last_stage:
                # Final RGB stage: squash to [0, 1] instead of BN + ReLU.
                up_layers.append(nn.Sigmoid())
            else:
                up_layers.extend([nn.BatchNorm2d(out_ch), nn.ReLU()])
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)


# Available backbones. ViT-B/16 starts from torchvision's IMAGENET1K_V1
# weights, is wrapped by ModifiedViT, and is passed through
# replace_ffn_with_lora with rank=8 (presumably LoRA adapters on the FFN
# layers — see util.utils).
BACKBONE_DICT = dict(
    VIT_B16=replace_ffn_with_lora(
        ModifiedViT(vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)), rank=8
    )
)

model = BACKBONE_DICT["VIT_B16"]
model = model.cuda()


# Per-channel mean and standard deviation of the ImageNet dataset.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Evaluation preprocessing: shorter side -> 256, centre crop 224x224,
# to tensor, normalise with the ImageNet statistics.
data_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Official ImageNet class labels, one per line, in a fixed order: a label's
# line index is its original ImageNet class id (see imagenet_class_to_idx).
IMAGENET_CLASSES_PATH = "data/imagenet100/imagenet_classes.txt"
# Raise explicitly rather than `assert`, which is stripped under `python -O`.
if not os.path.isfile(IMAGENET_CLASSES_PATH):
    raise FileNotFoundError(
        "Please download the ImageNet class label file imagenet_classes.txt! "
        "(expected at '{}')".format(IMAGENET_CLASSES_PATH)
    )
with open(IMAGENET_CLASSES_PATH, encoding="utf-8") as f:
    imagenet_classes = [line.strip() for line in f]


DATA_ROOT = './data/imagenet100'
imagenet_test_dataset = ImageFolder(
    root=os.path.join(DATA_ROOT, "test"), transform=data_transform
)
# ImageFolder numbers the class folders in lexicographic order of their names.
test_classes = list(imagenet_test_dataset.class_to_idx.keys())
unknown_test_classes = sorted(set(test_classes) - set(imagenet_classes))
if unknown_test_classes:
    raise ValueError(
        "Test set category is not in ImageNet category! Unknown: {}".format(
            unknown_test_classes
        )
    )

# Class name -> original ImageNet class id.
imagenet_class_to_idx = {
    cls_name: idx for idx, cls_name in enumerate(imagenet_classes)
}
# ImageFolder (test-set) class id -> original ImageNet class id.
current_id_to_original_id = {
    imagenet_test_dataset.class_to_idx[cls]: imagenet_class_to_idx[cls]
    for cls in imagenet_test_dataset.classes
}



# Adapt the classification head with util.utils.modify_head, using the
# test-set id -> original ImageNet id mapping built above, then move to GPU.
model = modify_head(model, current_id_to_original_id=current_id_to_original_id)
model = model.cuda()

# When cos_sim is on, keep an untouched copy of the backbone as it is *before*
# any method checkpoint is loaded. Its features are the reference for the
# cosine-similarity metric printed during the inversion.
cos_sim = True
if cos_sim:
    import copy

    copy_model = copy.deepcopy(model)

BACKBONE_RESUME_ROOT = None

# Backbone checkpoint (Backbone_task_0.pth) of each evaluated method.
# A switch with no entry here (e.g. 'gt') inverts the unmodified backbone.
RESUME_ROOTS = {
    'gslora': './all_baseline/witho_ema/exps_image/no-ema-proto/CLGSLoRA/start80forgetper20lr1e-2beta0.15-20250103155818/task-level/Backbone_task_0.pth',
    'ours': './LOG/test_image100/4types/20250218/type1_lw2g_back_df_nogumble_searchparams_20240216/lw2g_dr-01_df-005/exps_image/multistep/CLGSLoRA_pure_celoss-dr-coef_1/start80forgetper20lr1e-2beta0.15-20250218114731/task-level/Backbone_task_0.pth',
    'der++': './exps_image/forget-CL/CL-baseline/DER++0.5-start80forget20lr1e-4-20250223225436/task-level/Backbone_task_0.pth',
    'der': './exps_image/forget-CL/CL-baseline/DER0.1-start80forget20lr1e-4-20250223225436/task-level/Backbone_task_0.pth',
    'fdr': './exps_image/forget-CL/CL-baseline/FDR10-start80forget20lr1e-3-20250223225436/task-level/Backbone_task_0.pth',
    'lwf': './exps_image/forget-CL/CL-baseline/Lwf10-start80forget20lr1e-4-20250223225436/task-level/Backbone_task_0.pth',
    'scrub': './exps_image/forget-CL-baseline/CL-baseline-one/SCRUB-start80forget20lr1e-4-20250223225436/task-level/Backbone_task_0.pth',
    'scrub-u': './exps_image/forget-CL-baseline/CL-baseline-one/SCRUBsmooth-start80forget20lr1e-4-20250223225436/task-level/Backbone_task_0.pth',
    'ewc': './LOG/exps_image/multistep/forget-CL/CL-baseline/EWC10-start80forget20lr1e-4-20250212002751/task-level/Backbone_task_0.pth',
}

switch_list = ['ours', 'gslora', 'der++', 'der', 'fdr', 'lwf', 'scrub', 'scrub-u', 'ewc']

# NOTE(review): `block_idx` is read below but never assigned in this file
# (presumably it came from a notebook cell that is not shown). Fall back to 11,
# the default block of get_cls_token_output, unless it is already defined —
# confirm this is the intended layer.
if "block_idx" not in globals():
    block_idx = 11

# load_state_dict(strict=False) leaves missing (LoRA) keys untouched, so
# without a reset every method after the first would silently inherit the
# previous checkpoint's weights. Snapshot the pristine backbone once (on CPU to
# spare GPU memory) and restore it before each method.
base_state_dict = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

# Only the autoencoder is optimised. Freezing the backbone stops backward()
# from accumulating never-cleared gradients in its parameters. Gradients still
# flow through it to the generated image.
model.requires_grad_(False)

# Captured once so each per-method Logger wraps the real console stream
# instead of the previous iteration's Logger.
_console_stdout = sys.stdout


def _load_backbone_checkpoint(net, ckpt_path):
    """Non-strictly load ``ckpt_path`` into ``net``.

    Returns False (after printing a notice) when the file does not exist.
    Exits the process if any missing key is not a LoRA key.
    """
    if not os.path.isfile(ckpt_path):
        print(
            "No Checkpoint Found at '{}' . Please Have a Check or Continue to Train from Scratch".format(
                ckpt_path
            )
        )
        return False
    print("Loading Backbone Checkpoint '{}'".format(ckpt_path))
    missing_keys, unexpected_keys = net.load_state_dict(torch.load(ckpt_path), strict=False)
    if missing_keys:
        print("Missing keys: {}".format(missing_keys))
        print("\n")
        if any("lora" not in key for key in missing_keys):
            print("\033[31mWrong resume.\033[0m")
            sys.exit()
    if unexpected_keys:
        print("Unexpected keys: {}".format(unexpected_keys))
        print("\n")
    return True


def _report_progress(step, loss, embed, gt_features, generated_img, original_img_np, out_dir):
    """Print loss / cosine similarity / SSIM / PSNR and save the current image."""
    print(f"Iteration {step}, Loss: {loss.item()}")

    # gt_features is None when cos_sim is disabled (no reference model).
    if gt_features is not None:
        copy_embed = embed.detach().clone()
        print('embed', copy_embed.shape)
        print('copy_gt_features', gt_features.shape)
        # Compare index 0 along dim 1 (the CLS token, assuming (batch, tokens, dim)).
        similarity = F.cosine_similarity(copy_embed[:, 0, :], gt_features[:, 0, :], dim=1)
        print(f"Cosine Similarity: {similarity.item():.4f}")

    # (1, 3, H, W) in [0, 1] -> (H, W, 3) numpy for skimage / matplotlib.
    generated_np = (
        generated_img.detach().clamp(0, 1).squeeze(0).permute(1, 2, 0).cpu().numpy().astype('float32')
    )
    ssim_score = ssim(original_img_np, generated_np, data_range=1, win_size=11, channel_axis=2)
    psnr_score = psnr(original_img_np, generated_np, data_range=1.0)
    print(f"SSIM: {ssim_score:.4f}, PSNR: {psnr_score:.2f} dB")

    # A fresh figure per snapshot, closed afterwards. Repeated plt.imshow on the
    # implicit figure stacks images and leaks memory.
    fig, ax = plt.subplots()
    ax.imshow(generated_np)
    ax.axis('off')
    fig.savefig(out_dir + 'generated_image_{}.png'.format(step), format='png', bbox_inches='tight')
    plt.close(fig)


def _invert_features(net, reference_net, layer_idx, img_path, out_dir, num_iterations=9100):
    """Fit a fresh Autoencoder on fixed noise so its output reproduces net's features.

    The target is block ``layer_idx`` of ``net`` evaluated on the image at
    ``img_path``. Metrics and a snapshot are written at iterations 1000, 3000,
    ... (every 2000 steps starting at 1000).
    ``reference_net`` (may be None) supplies the features for the
    cosine-similarity metric.
    """
    with Image.open(img_path) as raw_img:
        # Force 3 channels: Normalize below expects RGB.
        img = raw_img.convert('RGB')
    print('img', img.size)

    # data_transform is the same pipeline the loop used to rebuild each time
    # (resize 256, centre crop 224, ToTensor, ImageNet normalisation).
    img_tensor = data_transform(img).unsqueeze(0).cuda()

    with torch.no_grad():
        target_features = get_cls_token_output(net, img_tensor, layer_idx).detach()
        gt_features = None
        if reference_net is not None:
            gt_features = get_cls_token_output(reference_net, img_tensor, layer_idx).detach().clone()

    # Reference image in [0, 1], HWC, for SSIM / PSNR — computed once, not per report.
    original_img_np = (
        denormalize(img_tensor, IMAGENET_MEAN, IMAGENET_STD)
        .clamp(0, 1)
        .squeeze(0)
        .permute(1, 2, 0)
        .cpu()
        .numpy()
    )

    # Fixed input noise with the image's shape: (1, 3, 224, 224) for a 224x224 crop.
    noise = torch.randn_like(img_tensor)

    autoencoder = Autoencoder().cuda()
    optimizer = optim.Adam(autoencoder.parameters(), lr=0.001)

    for step in range(num_iterations):
        optimizer.zero_grad()
        generated_img = autoencoder(noise)
        loss, embed = loss_function(generated_img, target_features, net, layer_idx)
        loss.backward()
        optimizer.step()

        if step % 2000 == 1000:
            _report_progress(step, loss, embed, gt_features, generated_img, original_img_np, out_dir)


for switch in switch_list:
    # Reset per method so an unknown switch (e.g. 'gt') or a missing
    # checkpoint never reuses the previous method's path or weights.
    BACKBONE_RESUME_ROOT = RESUME_ROOTS.get(switch)
    model.load_state_dict(base_state_dict)

    print('===========strat_resume')
    print('BACKBONE_RESUME_ROOT', BACKBONE_RESUME_ROOT)
    if BACKBONE_RESUME_ROOT:
        print("=" * 60)
        print(BACKBONE_RESUME_ROOT)
        loaded = _load_backbone_checkpoint(model, BACKBONE_RESUME_ROOT)
        print("=" * 60)
        if not loaded:
            # Inverting the base backbone here would mislabel its results as this method's.
            print("Skipping '{}': checkpoint not found.".format(switch))
            continue

    model.eval()
    img_path = './data/imagenet100/train/n02037110/n02037110_920.JPEG'  # a small bird at sea

    # NOTE(review): the folder name mentions n01685808_499 while img_path is
    # n02037110_920 — confirm which image the output directory should describe.
    save_path = './dip/0307/dip_result/dip_last/image/r_n01685808_499/{}/{}/'.format(switch, block_idx)
    os.makedirs(save_path, exist_ok=True)

    logger = Logger(save_path + 'output.log')
    sys.stdout = logger
    try:
        _invert_features(
            model,
            copy_model if cos_sim else None,
            block_idx,
            img_path,
            save_path,
        )
    finally:
        # Always restore the console and release the log file, even on failure.
        sys.stdout = _console_stdout
        logger.log.close()
