
# 予測したlatentと実際のlatentの誤差の分布を見る


In [None]:
from omegaconf import DictConfig, open_dict
from meg_decoding.models import get_model, Classifier
from meg_decoding.utils.get_dataloaders import get_dataloaders, get_samplers
from meg_decoding.utils.loss import *
from meg_decoding.dataclass.god import GODDatasetBase, GODCollator
from meg_decoding.utils.loggers import Pickleogger
# from meg_decoding.clip_utils.get_embedding import get_language_model
from torch.utils.data.dataset import Subset
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from PIL import Image
import os
from torch.utils.data import DataLoader

In [None]:
def run(args):
    """Plot the distribution of errors between predicted and target latents.

    Loads the trained brain encoder, runs it over the 'val' split (normalised
    with the statistics of the 'train' split) and draws two histograms:
    the raw element-wise error ``Z - Y`` and the error after standardising
    both the predictions and the targets.

    Relies on ``device``, ``seed_worker`` and ``g`` being defined elsewhere
    in the notebook/module.

    Args:
        args: Experiment configuration. This function reads
            ``args.num_workers`` and ``args.save_root`` directly and forwards
            ``args`` to ``GODDatasetBase``, ``get_model`` and ``Classifier``.

    Returns:
        None. The histograms are drawn on a new matplotlib figure.

    Raises:
        ValueError: If a batch from the dataloader has neither 3 nor 4 items.
    """
    # The 'train' split is instantiated only to obtain the normalisation
    # statistics that are applied to the evaluated 'val' split.
    source_dataset = GODDatasetBase(args, 'train', return_label=True)
    outlier_dataset = GODDatasetBase(
        args, 'val', return_label=True,
        mean_X=source_dataset.mean_X,
        mean_Y=source_dataset.mean_Y,
        std_X=source_dataset.std_X,
        std_Y=source_dataset.std_Y,
    )
    test_loader = DataLoader(
        outlier_dataset,
        batch_size=50,
        drop_last=True,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        worker_init_fn=seed_worker,
        generator=g,
    )
    brain_encoder = get_model(args).to(device)

    # Weights live in '<two levels above save_root>/weights'.
    weight_dir = os.path.join(os.path.join('/', *args.save_root.split('/')[:-2]), 'weights')
    last_weight_file = os.path.join(weight_dir, "model_last.pt")
    best_weight_file = os.path.join(weight_dir, "model_best.pt")
    # Prefer the best checkpoint and fall back to the last one.
    weight_file = best_weight_file if os.path.exists(best_weight_file) else last_weight_file
    # map_location lets a checkpoint saved on another device (e.g. GPU) be
    # loaded onto whatever `device` is in use now.
    brain_encoder.load_state_dict(torch.load(weight_file, map_location=device))
    print('weight is loaded from ', weight_file)

    classifier = Classifier(args)

    Zs = []
    Ys = []
    Ls = []
    brain_encoder.eval()
    with torch.no_grad():
        for batch in test_loader:
            if len(batch) == 3:
                X, Y, subject_idxs = batch
                Labels = None
            elif len(batch) == 4:
                X, Y, subject_idxs, Labels = batch
            else:
                raise ValueError("Unexpected number of items from dataloader.")

            X, Y = X.to(device), Y.to(device)

            Z = brain_encoder(X, subject_idxs)
            Zs.append(Z)
            Ys.append(Y)
            # Labels are only present in 4-item batches.
            if Labels is not None:
                Ls.append(Labels)

            # NOTE(review): the returned accuracies are not used below; the
            # call is kept in case Classifier has side effects -- confirm and
            # drop it if it does not.
            classifier(Z, Y, test=True)

    Zs = torch.cat(Zs, dim=0)
    Ys = torch.cat(Ys, dim=0)
    # NOTE(review): labels are collected but not used further in this function.
    Ls = torch.cat(Ls, dim=0).detach().cpu().numpy() if Ls else None
    raw_Es = Zs - Ys

    # Predictions: standardise along dim 0 (the concatenated batch axis),
    # then along dim 1. Targets are standardised along dim 1 only.
    # Assumes the latents are 2-D (samples, features) -- TODO confirm.
    Zs = Zs - Zs.mean(dim=0, keepdim=True)
    Zs = Zs / Zs.std(dim=0, keepdim=True)
    Zs = Zs - Zs.mean(dim=1, keepdim=True)
    Zs = Zs / Zs.std(dim=1, keepdim=True)
    Ys = Ys - Ys.mean(dim=1, keepdim=True)
    Ys = Ys / Ys.std(dim=1, keepdim=True)

    normalized_Es = Zs - Ys

    # matplotlib converts its input through NumPy, which cannot read tensors
    # that live on an accelerator, so move the errors to host memory first.
    raw_errors = raw_Es.detach().flatten().cpu().numpy()
    normalized_errors = normalized_Es.detach().flatten().cpu().numpy()

    fig, axes = plt.subplots(ncols=2, figsize=(10, 5))
    axes[0].hist(raw_errors, bins=100)
    axes[1].hist(normalized_errors, bins=100)
