In [None]:
import copy
import pickle

import matplotlib.pyplot as plt
from tqdm import tqdm

from modularTraining.constantFunctions import get_dataloaders, Parameters
from utils.constant import ATTRIBUTES
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from datamodules.celebadatamodule import CelebADataModule
from lightningmodules.classification import Classification
import numpy as np

import torch
from pytorch_lightning import Trainer

################################ Classification
from dataclasses import dataclass
import os, os.path as osp
from typing import Any, ClassVar, Dict, List, Optional


from torch.utils.data import Dataset, DataLoader
import torch
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms as transforms
from modularTraining.constantFunctions import get_idata



def get_prediction(classifier, trainer, images, num_labels=40):
    """Run ``trainer.predict`` over ``images`` and return one prediction per image.

    Each image is resized to 224x224 and normalized with per-channel mean/std
    0.5. It is then paired with an all-zero placeholder label tensor, so every
    dataset item has the ``[image, label]`` structure the classifier's predict
    step consumes.

    Args:
        classifier: module passed to ``trainer.predict``.
        trainer: Lightning ``Trainer`` used for inference.
        images: iterable of image tensors (iterating a batched tensor yields
            its per-image slices).
        num_labels: number of rows in the placeholder label tensor. Defaults
            to 40, the value previously hard-coded.

    Returns:
        list: element ``[2][0]`` of each per-batch prediction output. Batches
        hold one item (``batch_size=1``), so ``[0]`` picks that item.
        ``get_label_distribution`` iterates each entry as a collection of
        attribute names. NOTE(review): confirm against
        ``Classification.predict_step``.
    """
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    # Labels are not used for prediction; a zero placeholder keeps the item structure.
    data_list = [[transform(img), torch.zeros(num_labels, 1)] for img in images]

    predict_loader = DataLoader(dataset=data_list, batch_size=1, shuffle=False)
    batch_outputs = trainer.predict(classifier, predict_loader)  # without fine-tuning
    return [output[2][0] for output in batch_outputs]


def get_label_distribution(data_loader, classifier=None, trainer=None, attributes=None,
                           target_attribute='Young', max_hits=10, show=True):
    """Count predicted attribute labels over ``data_loader`` and collect ``target_attribute`` hits.

    For every batch, ``get_prediction`` is run on the images and each predicted
    attribute name is tallied. Images whose prediction contains
    ``target_attribute`` are optionally printed and plotted, then stored.
    Iteration stops after the batch in which the number of stored images
    reaches ``max_hits``.

    Args:
        data_loader: iterable of ``(images, labels)`` batches. Labels are ignored.
        classifier: forwarded to ``get_prediction``. ``None`` uses the
            notebook-global ``classifier`` (the original behaviour).
        trainer: forwarded to ``get_prediction``. ``None`` uses the
            notebook-global ``trainer``.
        attributes: names pre-seeded with a zero count. ``None`` uses the
            notebook-global ``attributes``.
        target_attribute: predicted label whose images are collected.
        max_hits: stop once this many images have been collected. ``None``
            scans the whole loader.
        show: when True, print and plot every collected image.

    Returns:
        tuple: ``(found, hits)``. ``found`` maps attribute name to count,
        ordered ascending by count. ``hits`` is the list of collected image
        tensors.
    """
    # Resolve notebook state lazily so existing one-argument calls keep working.
    namespace = globals()
    if classifier is None:
        classifier = namespace['classifier']
    if trainer is None:
        trainer = namespace['trainer']
    if attributes is None:
        attributes = namespace['attributes']

    found = {label: 0 for label in attributes}
    hits = []

    # Iterate the loader directly (not enumerate(...)) so tqdm can show a total.
    for images, _labels in tqdm(data_loader):
        predictions = get_prediction(classifier, trainer, images)
        for position, predicted in enumerate(predictions):
            for label in predicted:
                # .get tolerates labels absent from `attributes` instead of raising KeyError.
                found[label] = found.get(label, 0) + 1

            if target_attribute in predicted:
                if show:
                    print(f'Found {target_attribute}')
                    print(images.shape)
                    plt.imshow(images[position].permute(1, 2, 0))
                    plt.show()
                    print(f'Shown {predicted}')
                hits.append(images[position])

        # Checked once per batch, so the batch that crosses the limit is finished first.
        if max_hits is not None and len(hits) >= max_hits:
            break

    print('Attribute found')
    found = dict(sorted(found.items(), key=lambda item: item[1]))
    print(found)

    return found, hits


# def get_idata(name):
#     save_folder = f'/local/scratch/a/rahman89/PycharmProjects/modularCelebA/images256/{name}'
#
#     file = f'{save_folder}/images_10k_I_doMale.pkl'
#     with open(file, 'rb') as f:
#         intv_images = pickle.load(f)
#
#     file = f'{save_folder}/labels_10k_I_doMale.pkl'
#     with open(file, 'rb') as f:
#         intv_labels = pickle.load(f)
#
#     return intv_labels, intv_images


def get_train_test_loaders(dom_name, split_t=0.95, split_v=0.9,
                           folder='/local/scratch/a/rahman89/PycharmProjects/modularCelebA/images256'):
    """Load one pickled CelebA domain and build loaders for its 'Young' label.

    The function:
    - reads ``8_attribute_10k_celeba_{dom_name}.pkl`` (labels) and
      ``images_10k_celeba_{dom_name}.pkl`` (images) from ``folder``;
    - stores the images under ``domain_dataset['I']``;
    - passes the ``'Young'`` column, as a float column vector, together with
      the images to ``get_dataloaders``.

    Args:
        dom_name: domain suffix used in both pickle file names.
        split_t: split fraction forwarded to ``get_dataloaders``. Previously
            ignored in favour of a hard-coded 0.95.
        split_v: split fraction forwarded to ``get_dataloaders``. Previously
            ignored in favour of a hard-coded 0.9.
        folder: directory holding both pickles. Defaults to the original path.

    Returns:
        tuple: ``(domain_dataset, trainloader, validloader, testloader)``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(osp.join(folder, f'8_attribute_10k_celeba_{dom_name}.pkl'), 'rb') as f:
        domain_dataset = pickle.load(f)

    with open(osp.join(folder, f'images_10k_celeba_{dom_name}.pkl'), 'rb') as f:
        images = pickle.load(f)

    domain_dataset['I'] = images['I']
    dom_age = domain_dataset['Young'].reshape(-1, 1).type(torch.FloatTensor)
    dom_images = domain_dataset['I']

    # Forward the caller's split fractions instead of hard-coded values.
    trainloader, validloader, testloader = get_dataloaders(
        dom_age, dom_images, split_t=split_t, split_v=split_v)

    return domain_dataset, trainloader, validloader, testloader




# if __name__ == '__main__':
    # loading the classifier



In [None]:
# Attribute names used to seed the label counts in get_label_distribution.
attributes = list(ATTRIBUTES.values())
config = Parameters()

# Load the Lightning classifier weights. map_location='cpu' makes loading independent
# of the device the checkpoint was saved on; load_state_dict copies the weights
# into the module afterwards.
checkpoint = torch.load(config.inference_param.ckpt_path, map_location='cpu')
classifier = Classification(config.inference_param)
classifier.load_state_dict(checkpoint["state_dict"])
print('Classifier loaded')

# Trainer used only for trainer.predict; train/val batch limits are set to 0.
trainer = Trainer(devices=config.hparams.gpu, limit_train_batches=0, limit_val_batches=0)
print('loaded')


In [None]:
# Load the interventional samples in two chunks and merge them.
# NOTE(review): get_idata is called here with one name, but with two pickle paths in a
# later cell; confirm which signature the imported helper supports.
l1, i1 = get_idata('fake5000')   # first chunk (presumably 5000 samples, per the folder name)
l2, i2 = get_idata('fake10000')  # second chunk (presumably the next 5000 samples)
intv_labels = {
    'Male': np.concatenate([l1['Male'], l2['Male']]),
    'Young': np.concatenate([l1['Young'], l2['Young']]),
}
intv_images = np.concatenate([i1['I'], i2['I']])
intv_age = np.array(intv_labels['Young'], dtype='float32').reshape(-1, 1)

# torch.tensor copies its input, so no deepcopy is needed to decouple from intv_labels.
intv_dataset = {key: torch.tensor(value) for key, value in intv_labels.items()}
intv_dataset['I'] = intv_images

# Select the Male == 0 & Young == 0 subgroup (labelled "female old" below).
male = 0
young = 0
retfo = (intv_dataset['Male'] == male) & (intv_dataset['Young'] == young)
female_old = intv_dataset['Young'][retfo].reshape(-1, 1).type(torch.FloatTensor)
# intv_dataset['I'] is a numpy array, so index it with a numpy mask.
female_image = intv_dataset['I'][retfo.numpy()]
m0y0_trainloaderI, _, _ = get_dataloaders(female_old, female_image, split_t=0.95, split_v=0.9)

print('label distribution in female old interventional data')
# get_label_distribution returns (counts, collected_images); unpack both.
distI, error = get_label_distribution(m0y0_trainloaderI)



In [None]:

# Load the generated Idom m0y0 domain(s), shuffle, and inspect predicted labels.
SHUFFLE_SEED = 0  # fixed seed so the shuffle (and hence the data split) is reproducible

intv_labels = {}

male_list = []
young_list = []
img_list = []
for m, y in zip([0], [0]):
    # Images and labels are read from separate pickles for the m{m}y{y} domain.
    ll, ii = get_idata(f'fakeIdom{m}y{y}/images_Idom{m}y{y}.pkl', f'fakeIdom{m}y{y}/labels_Idom{m}y{y}.pkl')
    male_list.append(ll['Male'])
    young_list.append(ll['Young'])
    img_list.append(ii['I'])

intv_labels['Male'] = np.concatenate(male_list)
intv_labels['Young'] = np.concatenate(young_list)
intv_images = np.concatenate(img_list)

# Shuffle with a permutation. torch.randint(0, h, (h,)) samples indices WITH
# replacement, which duplicated some samples and dropped others (and let duplicates
# land in both the train and test splits).
h = intv_images.shape[0]
shuffle_idx = np.random.default_rng(SHUFFLE_SEED).permutation(h)
intv_labels['Male'] = intv_labels['Male'][shuffle_idx]
intv_labels['Young'] = intv_labels['Young'][shuffle_idx]
intv_images = intv_images[shuffle_idx]

intv_age = np.array(intv_labels['Young'], dtype='float32').reshape(-1, 1)
trainloaderI, validloaderI, testloaderI = get_dataloaders(intv_age, intv_images, split_t=0.95, split_v=0.9)

distI, error = get_label_distribution(trainloaderI)


In [None]:
# Number of images get_label_distribution collected because their prediction contained 'Young'.
len(error)

In [None]:
# Batch the collected per-image tensors along a new leading axis and move them to the GPU.
errors = torch.stack(error).to('cuda')
errors.shape

In [None]:
# Re-run the Lightning classifier on only the collected images.
predictions = get_prediction(classifier=classifier, trainer=trainer, images=errors)

In [None]:
# One line of output per collected image.
for predicted_labels in predictions:
    print(predicted_labels)

In [None]:
# Load a second classifier (MobileNet with a single output) trained on domain 'dom1'.
from Classifiers.mobileNet.model import MobileNet

num_labels = 1
save_path = "/local/scratch/a/rahman89/PycharmProjects/modularCelebA/Classifiers/new_weights"
cur_domain = 'dom1'
checkpoint_path = f'{save_path}/{cur_domain}_model_checkpoint.pth'

# eval() switches any BatchNorm/Dropout layers to inference behaviour for the cells below.
classifier2 = MobileNet(num_labels).to('cuda').eval()
# Use a separate name so the Lightning `checkpoint` loaded earlier is not overwritten;
# map_location='cpu' avoids depending on the device the file was saved from.
mobilenet_checkpoint = torch.load(checkpoint_path, map_location='cpu')
classifier2.load_state_dict(mobilenet_checkpoint['model_state_dict'])

In [None]:
# Show MobileNet's raw output for each collected image next to the image itself.
classifier2.eval()  # idempotent; guarantees inference mode regardless of earlier cells
with torch.no_grad():  # inference only, so no autograd graph is built
    for img in errors:
        ret = classifier2(img.unsqueeze(0))
        print(ret)
        # NOTE(review): if images were normalized (e.g. to [-1, 1]), imshow will clip;
        # rescale for display if the plots look washed out.
        plt.imshow(img.cpu().permute(1, 2, 0))
        plt.show()