# Feature Extractor 

In [1]:
from torchvision.datasets import CelebA
from torchvision.models import resnet50
from torchvision.transforms import Resize, ToTensor, Compose,Normalize
from torch.nn import Identity
from torch.utils.data import DataLoader
import torch
from torch.utils.data import TensorDataset
from easydict import EasyDict as edict
import torchvision.transforms as tt
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from tqdm.notebook import tqdm

def mnist_cifar(root, split, binarize=False):
    """Load one split of a saved MNIST-CIFAR file as a tensor dataset.

    Args:
        root: Path passed to ``torch.load``. The loaded object is indexed by
            split name, and each split by ``'data'``, ``'targets'`` and
            ``'group'``.
        split: Name of the split to return. ``'train'`` and ``'id'`` are both
            carved out of the stored ``'train'`` entry with a fixed-seed
            9000/1000 random split; any other value is used directly as a
            key of the loaded object.
        binarize: When true, targets are replaced by ``(targets > 4)`` cast
            to float.

    Returns:
        A ``TensorDataset`` of ``(data, targets, group)`` for ordinary
        splits, or the matching part of the random split for ``'train'``
        (first part, 9000 samples) and ``'id'`` (second part, 1000 samples).
    """
    splits_by_name = torch.load(root)

    # The "id" hold-out does not exist on disk; it is taken from "train".
    stored_key = 'train' if split == "id" else split
    chosen = splits_by_name[stored_key]

    if binarize:
        chosen['targets'] = (chosen['targets'] > 4).float()

    full_dataset = TensorDataset(
        chosen['data'], chosen['targets'], chosen['group']
    )

    if split not in ('train', 'id'):
        return full_dataset

    # Fixed seed so that "train" and "id" are always the same disjoint parts.
    # The hard-coded lengths mean the stored train entry is expected to hold
    # exactly 10000 samples.
    seeded_rng = torch.Generator().manual_seed(42)
    train_part, id_part = random_split(
        full_dataset, [9000, 1000], generator=seeded_rng
    )
    return train_part if split == 'train' else id_part

def get_dataset(split, corr, root="../../datasets/MNISTCIFAR", n_dims=1, **kwargs):
    """Return the MNIST-CIFAR dataset for one correlation level and split.

    Args:
        split: Split name forwarded to ``mnist_cifar``.
        corr: Correlation value; it is interpolated into the file name.
        root: Directory containing the ``.pth`` files.
        n_dims: ``10`` selects ``MNIST_CIFAR_{corr}.pth``; any other value
            selects ``MNIST_CIFAR_binary_{corr}.pth``.
        **kwargs: Accepted but not used.

    Returns:
        Whatever ``mnist_cifar`` returns for the resolved path and ``split``.
    """
    file_name = (
        f"MNIST_CIFAR_{corr}.pth"
        if n_dims == 10
        else f"MNIST_CIFAR_binary_{corr}.pth"
    )
    return mnist_cifar(root + f"/{file_name}", split=split)

In [2]:
# Run configuration: dataset to featurize and the split to read.
dataset = "mnistcifar"
split = "test"

# Per-channel mean and standard deviation for ImageNet normalization.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Image pipeline: resize to 224x224, convert to a tensor, then normalize each
# channel with the statistics above.
t = Compose(
    [
        Resize((224, 224)),
        ToTensor(),
        Normalize(mean=mean, std=std),
    ]
)


In [4]:
# Extract ResNet-50 (ImageNet weights) features for every correlation level
# and save them with their labels and group ids, one .pth file per level.
corrs = [0.0, 0.25, 0.5, 0.75, 1.0]
split="train"

# Use the GPU when one is available, otherwise fall back to the CPU so the
# cell still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the extractor once: the pretrained weights do not depend on `corr`.
# Replacing the classification head with Identity exposes the pooled features
# (2048 per the output file name).
extractor = resnet50(weights="IMAGENET1K_V1")
extractor.fc = Identity()
# eval() is required for feature extraction: in training mode BatchNorm
# normalizes with per-batch statistics and updates its running averages, so
# the features would depend on batch composition and ordering.
extractor = extractor.to(device).eval()

for corr in tqdm(corrs,total=len(corrs)):
    if dataset == "mnistcifar":
        train_ds = get_dataset(split,corr)
    elif dataset == "celeba":
        # NOTE(review): torchvision's CelebA yields (image, target) pairs,
        # while the loop below unpacks three values and calls .item() on each
        # label -- this branch looks like it cannot run as written; confirm
        # before switching `dataset`.
        train_ds = CelebA(root="../../datasets",target_type="attr",split=split, download=False,transform=t)
    else:
        # Fail loudly instead of reusing a stale `train_ds` from a previous
        # iteration (or hitting a NameError on the first one).
        raise ValueError(f"Unknown dataset: {dataset!r}")

    # NOTE(review): the MNIST-CIFAR tensors are fed to the network as stored;
    # the resize/normalize transform `t` is only applied on the CelebA path.
    # Presumably the saved tensors are already preprocessed -- verify.

    # No shuffling: with the model in eval mode ordering does not affect the
    # features, and a fixed order makes the saved file reproducible.
    train_dl = DataLoader(train_ds, batch_size=32, shuffle=False)
    total_batches = len(train_dl)
    features = []
    labels = []
    groups = []

    # Inference only: skip building the autograd graph for every batch.
    with torch.no_grad():
        for x, y, g in tqdm(train_dl, total=total_batches, desc="Processing"):
            features.append(extractor(x.to(device)).cpu())
            # Store plain Python scalars, one per sample.
            labels.extend(v.item() for v in y)
            groups.extend(v.item() for v in g)

    feats_torch = torch.cat(features, dim=0)
    torch.save( {"x": feats_torch, "y": labels, "g": groups}, f"{dataset}_{corr}_imgnet_2048_{split}.pth")

  0%|          | 0/5 [00:00<?, ?it/s]

Processing:   0%|          | 0/282 [00:00<?, ?it/s]

Processing:   0%|          | 0/282 [00:00<?, ?it/s]

Processing:   0%|          | 0/282 [00:00<?, ?it/s]

Processing:   0%|          | 0/282 [00:00<?, ?it/s]

Processing:   0%|          | 0/282 [00:00<?, ?it/s]