In [1]:
from Phi import Extractor_CLIP, Phi
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import numpy as np
import pandas as pd

from torch.utils.data import DataLoader
from torch.autograd import Variable

from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader
from torchvision import transforms

from Generate_data_augmentation import *

from wilds.datasets.camelyon17_dataset import Camelyon17Dataset
from wilds.datasets.domainnet_dataset import DomainNetDataset
from wilds.datasets.waterbirds_dataset import WaterbirdsDataset
from wilds.datasets.wilds_dataset import WILDSSubset

from PIL import Image

import warnings

# Whether a CUDA device is visible to this kernel (is_available() already returns a bool).
cuda = torch.cuda.is_available()

# NOTE(review): this silences *every* warning for the rest of the kernel session,
# including deprecation and data-loading warnings; consider narrowing the filter.
warnings.filterwarnings('ignore')

In [2]:
# Constructor argument for the CLIP extractor, and which WILDS dataset this run processes.
z_hidden, dataset_name = 20, "waterbirds"

In [3]:
# CLIP feature extractor used by extract_features(); z_hidden is its only
# constructor argument (presumably the reduced/latent feature size — confirm
# in Phi.Extractor_CLIP).
extractor_model = Extractor_CLIP(z_hidden)

You are using a model of type clip to instantiate a model of type clip_vision_model. This is not supported for all configurations of models and can yield errors.
Some weights of the model checkpoint at openai/clip-vit-base-patch32 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.6.self_attn.q_proj.bias', 'text_model.encoder.layers.0.mlp.fc2.bias', 'text_model.encoder.layers.11.self_attn.out_proj.bias', 'text_model.encoder.layers.9.self_attn.q_proj.bias', 'text_model.encoder.layers.0.mlp.fc1.weight', 'text_model.encoder.layers.5.layer_norm1.bias', 'text_model.encoder.layers.10.mlp.fc2.bias', 'text_model.encoder.layers.3.self_attn.q_proj.weight', 'text_model.encoder.layers.4.mlp.fc2.bias', 'logit_scale', 'text_model.encoder.layers.3.mlp.fc1.weight', 'text_model.encoder.layers.6.layer_norm1.weight', 'text_model.encoder.layers.8.self_attn.out_proj.weight', 'text_model.encoder.layers.9.mlp.fc2.bias', 'text_mo

ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.


In [4]:
batch_size = 1280
# Root directory that holds the extracted WILDS datasets. Override it with the
# WILDS_ROOT_DIR environment variable instead of editing this machine-specific path.
root_dir = os.environ.get("WILDS_ROOT_DIR", "/hdd2/wilds_data")

In [5]:
class CustomDomainnet(DomainNetDataset):
    """DomainNet (official split scheme, ``use_sentry=True``) whose inputs carry their label.

    ``get_input`` returns an ``(image, label)`` pair instead of just the image, and
    ``get_subset`` builds ``WILDSSubset`` objects with ``do_transform_y=True`` —
    presumably so a label-aware transform such as ``MyCompose`` is called as
    ``transform(x, y)``; confirm against the installed WILDS version.
    """

    def __init__(self, **dataset_kwargs):
        # split_scheme and use_sentry are fixed here; passing either again in
        # dataset_kwargs raises a duplicate-keyword TypeError.
        super().__init__(split_scheme='official', use_sentry=True, **dataset_kwargs)

    def get_input(self, idx):
        """Return ``(RGB PIL image, label)`` for example ``idx``."""
        img_path = os.path.join(self.data_dir, self._input_image_paths[idx])
        # The context manager closes the file handle deterministically; convert()
        # returns a new, fully loaded image that remains valid afterwards.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        label = self._y_array.numpy()[idx]
        return img, label

    def get_subset(self, split, frac=1.0, transform=None):
        """Return the examples of ``split`` as a ``WILDSSubset`` with ``do_transform_y=True``.

        Args:
            split: a key of ``self.split_dict`` (e.g. "train", "val", "test").
            frac: fraction of the split to keep; below 1.0 a random subset is drawn
                with NumPy's global RNG and its indices are kept sorted.
            transform: transform handed to the subset.

        Raises:
            ValueError: if ``split`` is not a key of ``self.split_dict``.
        """
        if split not in self.split_dict:
            raise ValueError(f"Split {split} not found in dataset's split_dict.")

        split_mask = self.split_array == self.split_dict[split]
        split_idx = np.where(split_mask)[0]

        if frac < 1.0:
            # Randomly sample a fraction of the split
            num_to_retain = int(np.round(float(len(split_idx)) * frac))
            split_idx = np.sort(np.random.permutation(split_idx)[:num_to_retain])

        return WILDSSubset(self, split_idx, transform, do_transform_y=True)

In [6]:
class CustomWaterbirds(WaterbirdsDataset):
    """Waterbirds (official split scheme) whose inputs carry their label.

    ``get_input`` returns an ``(image, label)`` pair instead of just the image, and
    ``get_subset`` builds ``WILDSSubset`` objects with ``do_transform_y=True`` —
    presumably so a label-aware transform such as ``MyCompose`` is called as
    ``transform(x, y)``; confirm against the installed WILDS version.
    """

    def __init__(self, **dataset_kwargs):
        # split_scheme is fixed here; passing it again in dataset_kwargs raises
        # a duplicate-keyword TypeError.
        super().__init__(split_scheme='official', **dataset_kwargs)

    def get_input(self, idx):
        """Return ``(RGB PIL image, label)`` for example ``idx``."""
        img_path = os.path.join(self.data_dir, self._input_array[idx])
        # The context manager closes the file handle deterministically; convert()
        # returns a new, fully loaded image that remains valid afterwards.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        label = self._y_array.numpy()[idx]
        return img, label

    def get_subset(self, split, frac=1.0, transform=None):
        """Return the examples of ``split`` as a ``WILDSSubset`` with ``do_transform_y=True``.

        Args:
            split: a key of ``self.split_dict`` (e.g. "train", "val", "test").
            frac: fraction of the split to keep; below 1.0 a random subset is drawn
                with NumPy's global RNG and its indices are kept sorted.
            transform: transform handed to the subset.

        Raises:
            ValueError: if ``split`` is not a key of ``self.split_dict``.
        """
        if split not in self.split_dict:
            raise ValueError(f"Split {split} not found in dataset's split_dict.")

        split_mask = self.split_array == self.split_dict[split]
        split_idx = np.where(split_mask)[0]

        if frac < 1.0:
            # Randomly sample a fraction of the split
            num_to_retain = int(np.round(float(len(split_idx)) * frac))
            split_idx = np.sort(np.random.permutation(split_idx)[:num_to_retain])

        return WILDSSubset(self, split_idx, transform, do_transform_y=True)


In [7]:
class CustomCamelyon(Camelyon17Dataset):
    """Camelyon17 (official split scheme) whose inputs carry their label.

    ``get_input`` returns an ``(image, label)`` pair instead of just the image, and
    ``get_subset`` builds ``WILDSSubset`` objects with ``do_transform_y=True`` —
    presumably so a label-aware transform such as ``MyCompose`` is called as
    ``transform(x, y)``; confirm against the installed WILDS version.
    """

    def __init__(self, **dataset_kwargs):
        # split_scheme is fixed here; passing it again in dataset_kwargs raises
        # a duplicate-keyword TypeError.
        super().__init__(split_scheme='official', **dataset_kwargs)

    def get_input(self, idx):
        """Return ``(RGB PIL image, label)`` for example ``idx``."""
        img_path = os.path.join(self.data_dir, self._input_array[idx])
        # The context manager closes the file handle deterministically; convert()
        # returns a new, fully loaded image that remains valid afterwards.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        label = self._y_array.numpy()[idx]
        return img, label

    def get_subset(self, split, frac=1.0, transform=None):
        """Return the examples of ``split`` as a ``WILDSSubset`` with ``do_transform_y=True``.

        Args:
            split: a key of ``self.split_dict`` (e.g. "train", "val", "test").
            frac: fraction of the split to keep; below 1.0 a random subset is drawn
                with NumPy's global RNG and its indices are kept sorted.
            transform: transform handed to the subset.

        Raises:
            ValueError: if ``split`` is not a key of ``self.split_dict``.
        """
        if split not in self.split_dict:
            raise ValueError(f"Split {split} not found in dataset's split_dict.")

        split_mask = self.split_array == self.split_dict[split]
        split_idx = np.where(split_mask)[0]

        if frac < 1.0:
            # Randomly sample a fraction of the split
            num_to_retain = int(np.round(float(len(split_idx)) * frac))
            split_idx = np.sort(np.random.permutation(split_idx)[:num_to_retain])

        return WILDSSubset(self, split_idx, transform, do_transform_y=True)

In [8]:
# Label-carrying wrapper for each supported WILDS dataset name. An unknown name
# used to fall through silently to Camelyon17 (e.g. "iwildcam" would load the
# wrong data while dataset_orig loaded the right one); it now fails loudly.
custom_dataset_classes = {
    'waterbirds': CustomWaterbirds,
    'domainnet': CustomDomainnet,
    'camelyon17': CustomCamelyon,
}
if dataset_name not in custom_dataset_classes:
    raise ValueError(
        f"Unsupported dataset_name {dataset_name!r}; "
        f"expected one of {sorted(custom_dataset_classes)}"
    )
dataset = custom_dataset_classes[dataset_name](download=False, root_dir=root_dir)

In [9]:
# Unmodified WILDS dataset (no label-carrying get_input override), loaded with
# the same name and root as the custom wrapper above.
orig_dataset_kwargs = {"dataset": dataset_name, "download": False, "root_dir": root_dir}
if dataset_name == 'domainnet':
    # Match CustomDomainnet, which also constructs DomainNet with use_sentry=True.
    orig_dataset_kwargs["use_sentry"] = True
dataset_orig = get_dataset(**orig_dataset_kwargs)

In [10]:
# Show the mapping from split name to the integer id stored in the dataset's split array.
dataset_orig.split_dict

{'train': 0, 'val': 1, 'test': 2}

In [11]:
# Per-example metadata table shipped with the dataset (columns shown in the sample below).
dataset_version_dir = os.path.join(root_dir, f"{dataset_name}_v1.0")
categories_map_file = os.path.join(dataset_version_dir, "metadata.csv")
categories = pd.read_csv(categories_map_file)
categories.sample(5)

Unnamed: 0,img_id,img_filename,y,split,place,place_filename
6937,6938,119.Field_Sparrow/Field_Sparrow_0111_113899.jpg,0,0,0,/f/forest/broadleaf/00001033.jpg
9028,9029,154.Red_eyed_Vireo/Red_Eyed_Vireo_0101_156988.jpg,0,0,1,/o/ocean/00003200.jpg
4141,4142,072.Pomarine_Jaeger/Pomarine_Jaeger_0024_61281...,1,0,1,/o/ocean/00002901.jpg
3738,3739,065.Slaty_backed_Gull/Slaty_Backed_Gull_0020_7...,1,0,1,/o/ocean/00001006.jpg
1683,1684,030.Fish_Crow/Fish_Crow_0079_26030.jpg,0,0,0,/f/forest/broadleaf/00001217.jpg


In [12]:
# Distinct split ids present in the metadata table.
np.unique(categories['split'].to_numpy())

array([0, 1, 2])

In [13]:
# Sorted distinct class labels; used below to build the per-label augmentation mappings.
label_column = categories['y'].to_numpy()
labels = np.unique(label_column)
labels

array([0, 1])

In [14]:
class MyCompose(object):
    """Compose transforms where some steps also need the target.

    For each transform ``t``, in order:

    * ``t._get_name() == "Custom"``: called as ``img = t(img, tar)`` (label-aware steps).
    * other objects exposing ``_get_name`` (e.g. torchvision ``nn.Module`` transforms
      such as ``Resize``): called as ``img = t(img)``. If ``img`` is still the
      ``(image, label)`` tuple produced by ``get_input``, only the image is passed.
    * any other callable (e.g. ``ToTensor``): called as ``img = t(img)``.

    Exceptions raised by a transform propagate unchanged. The previous broad
    ``except`` retried with a different call signature, hiding the real error.
    It also indexed ``img[0]`` on *any* input, so e.g. a tensor lost all but its
    first channel.

    Returns:
        ``(img, tar)``, with ``tar`` unchanged.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, tar):
        for t in self.transforms:
            get_name = getattr(t, "_get_name", None)
            if get_name is not None and get_name() == "Custom":
                img = t(img, tar)
                continue
            if get_name is not None and isinstance(img, tuple):
                # get_input() yields (image, label); the module transform takes only the image.
                img = img[0]
            img = t(img)
        return img, tar

In [15]:
def get_train_dataloader(transform):
    """Wrap the "train" split of the notebook-level ``dataset`` in an unshuffled DataLoader.

    ``transform`` is forwarded to ``dataset.get_subset``. Batches use the
    notebook-level ``batch_size``.
    """
    train_subset = dataset.get_subset(split="train", transform=transform)
    loader = DataLoader(train_subset, batch_size=batch_size, shuffle=False)
    return loader

In [16]:
IMAGE_SIZE = (448, 448)


def _resized_loader(*augmentations):
    """Train loader: resize to IMAGE_SIZE, apply ``augmentations`` in order, then ToTensor."""
    pipeline = [transforms.Resize(IMAGE_SIZE), *augmentations, transforms.ToTensor()]
    return get_train_dataloader(MyCompose(pipeline))


# Single-augmentation loaders; each mapping assigns a transformation parameter per label.
cmap_mapping = generate_transformation_mapping(labels, cmaps)
trainloader_color = _resized_loader(Change_cmap(cmap_mapping))

rotation_mapping = generate_transformation_mapping(labels, angles)
trainloader_rotation = _resized_loader(Rotate_Image(rotation_mapping))

zoom_mapping = generate_transformation_mapping(labels, zoom_factors)
trainloader_zoom = _resized_loader(Zoom_Image(zoom_mapping))

shift_mapping = generate_transformation_mapping(labels, shift_factors)
trainloader_shift = _resized_loader(Shift_Image(shift_mapping))

# Combined augmentation: rotation followed by shift.
trainloader_mix_1 = _resized_loader(Rotate_Image(rotation_mapping), Shift_Image(shift_mapping))

# Unaugmented baseline built from the stock WILDS dataset (plain Compose, no target passed).
traindata_orig = dataset_orig.get_subset(
    split="train",
    transform=transforms.Compose([transforms.Resize(IMAGE_SIZE), transforms.ToTensor()]),
)
trainloader_orig = DataLoader(traindata_orig, batch_size=batch_size, shuffle=False)

In [17]:
# Combined augmentation: zoom followed by shift.
mix_2_pipeline = MyCompose([
    transforms.Resize((448, 448)),
    Zoom_Image(zoom_mapping),
    Shift_Image(shift_mapping),
    transforms.ToTensor(),
])
trainloader_mix_2 = get_train_dataloader(mix_2_pipeline)

In [18]:
# Regenerate the colour-map mapping (overwrites cmap_mapping). Presumably this
# yields a different per-label assignment — confirm in Generate_data_augmentation.
cmap_mapping = generate_transformation_mapping(labels, cmaps)
color_2_pipeline = MyCompose([
    transforms.Resize((448, 448)),
    Change_cmap(cmap_mapping),
    transforms.ToTensor(),
])
trainloader_color_2 = get_train_dataloader(color_2_pipeline)

In [19]:
gaussian_pipeline = MyCompose([
    transforms.Resize((448, 448)),
    Apply_Filter('gaussian'),
    transforms.ToTensor(),
])
trainloader_gaussian = get_train_dataloader(gaussian_pipeline)

In [20]:
def _filtered_loader(filter_name):
    """Train loader: resize, apply the named filter via Apply_Filter, then ToTensor."""
    pipeline = [transforms.Resize((448, 448)), Apply_Filter(filter_name), transforms.ToTensor()]
    return get_train_dataloader(MyCompose(pipeline))


trainloader_spline = _filtered_loader('spline')
trainloader_uniform = _filtered_loader('uniform')

In [21]:
def show_samples(loader, cols=6):
    """Plot ``cols`` randomly chosen images from the first batch of ``loader``, titled by label.

    Args:
        loader: iterable of batches whose first two elements are an image tensor
            batch and a label tensor. Images are assumed (N, C, H, W) with values
            matplotlib can display — TODO confirm for every loader used here.
        cols: number of axes in the single row. This was previously overridden by
            a hardcoded 6, so cols > 6 crashed. If the batch holds fewer images,
            the remaining axes stay empty.
    """
    # squeeze=False keeps a 2-D array of Axes even when cols == 1.
    fig, axs = plt.subplots(1, cols, figsize=(10, 12), constrained_layout=True, squeeze=False)
    axs = axs.ravel()

    # Only the first batch is needed; the old loop also loaded a second batch before breaking.
    first_batch = next(iter(loader), None)
    if first_batch is not None:
        imgs, batch_labels = first_batch[0], first_batch[1]
        n_available = imgs.shape[0]
        # Never request more distinct samples than the batch holds (replace=False would raise).
        random_idx = np.random.choice(np.arange(n_available), min(cols, n_available), replace=False)
        for ax, idx in zip(axs, random_idx):
            ax.imshow(imgs[idx].permute(1, 2, 0))
            ax.set_title(batch_labels[idx].item(), fontsize=10)

    for ax in axs:
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

In [22]:
# show_samples(trainloader_uniform, cols=6)

In [23]:
# show_samples(trainloader_gaussian, cols=6)

In [24]:
# show_samples(trainloader_color, cols=6)

In [25]:
# show_samples(trainloader_rotation, cols=6)

In [26]:
# show_samples(trainloader_zoom, cols=6)

In [27]:
# show_samples(trainloader_shift, cols=6)

In [28]:
# show_samples(trainloader_orig, cols=6)

In [29]:
# show_samples(trainloader_mix_1, cols=6)

In [30]:
# for i, (c, r, z, s, o) in tqdm(enumerate(zip(trainloader_color, trainloader_rotation, trainloader_zoom, trainloader_shift, trainloader_orig))):
#     if i > 1:
#         break
#     colored = c[0].cpu().detach().numpy()
#     rotated = r[0].cpu().detach().numpy()
#     zoomed = z[0].cpu().detach().numpy()
#     shifted = s[0].cpu().detach().numpy()
#     original = o[0].cpu().detach().numpy()
#     plt.figure(figsize=(12, 8))
#     for j, item in enumerate(colored):
#         if j >= 9: break
#         plt.subplot(5, 9, j+1)
#         im2display = item.transpose((1,2,0))
#         im2display = (im2display * 255).astype(np.uint8)
#         plt.xticks([])
#         plt.yticks([])
#         plt.imshow(im2display)

#     for j, item in enumerate(rotated):
#         if j >= 9: break
#         plt.subplot(5, 9, 9+j+1)
#         im2display = item.transpose((1,2,0))
#         im2display = (im2display * 255).astype(np.uint8)
#         plt.xticks([])
#         plt.yticks([])
#         plt.imshow(im2display)
        
#     for j, item in enumerate(zoomed):
#         if j >= 9: break
#         plt.subplot(5, 9, 18+j+1)
#         im2display = item.transpose((1,2,0))
#         im2display = (im2display * 255).astype(np.uint8)
#         plt.xticks([])
#         plt.yticks([])
#         plt.imshow(im2display)
    
    
#     for j, item in enumerate(shifted):
#         if j >= 9: break
#         plt.subplot(5, 9, 27+j+1)
#         im2display = item.transpose((1,2,0))
#         im2display = (im2display * 255).astype(np.uint8)
#         plt.xticks([])
#         plt.yticks([])
#         plt.imshow(im2display)
    
#     for j, item in enumerate(original):
#         if j >= 9: break
#         plt.subplot(5, 9, 36+j+1)
#         im2display = item.transpose((1,2,0))
#         im2display = (im2display * 255).astype(np.uint8)
#         plt.xticks([])
#         plt.yticks([])
#         plt.imshow(im2display)
#     plt.tight_layout()
#     plt.show()

In [31]:
def extract_features(dataloader):
    """Return ``Phi(extractor_model).get_z_features(dataloader)``.

    A fresh ``Phi`` wrapper around the shared ``extractor_model`` is built on every
    call. Callers in this notebook unpack the result into four values: full
    features, PCA features, feature mapping and components.
    """
    return Phi(extractor_model).get_z_features(dataloader)

In [32]:
# colored_features = extract_features(trainloader_color)
# rotated_features = extract_features(trainloader_rotation)
# zoom_features = extract_features(trainloader_zoom)
# shift_features = extract_features(trainloader_shift)
# Train-split features from the unaugmented loader; the mapping and components are kept for saving.
(orig_features_full,
 orig_features_pca,
 feature_mapping,
 components) = extract_features(trainloader_orig)

4it [03:35, 53.92s/it]


In [33]:
# gaussian_features = extract_features(trainloader_gaussian)
# uniform_features = extract_features(trainloader_uniform)

In [34]:
# mix_1_features = extract_features(trainloader_mix_1)
# mix_2_features = extract_features(trainloader_mix_2)

In [35]:
# colored_features_2 = extract_features(trainloader_color_2)

In [36]:
# Test split of the stock WILDS dataset, preprocessed exactly like trainloader_orig.
test_transform = transforms.Compose([transforms.Resize((448, 448)), transforms.ToTensor()])
test_data = dataset_orig.get_subset("test", transform=test_transform)
testloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [37]:
def get_metadata(loader):
    """Stack the third element (metadata) of every batch of ``loader`` into one array.

    The whole loader is iterated, so every image is still loaded and transformed
    even though only the metadata is kept.
    """
    per_batch = [batch_metadata for _, (_, _, batch_metadata) in tqdm(enumerate(loader))]
    return np.vstack(per_batch)

In [38]:
# Metadata rows for every test example, in loader order (testloader is unshuffled).
test_metadata = get_metadata(testloader)

5it [00:25,  5.10s/it]


In [39]:
# Test-split features. The mapping/components returned for this split are discarded;
# the train-split ones were kept when extracting orig_features_*.
# NOTE(review): extract_features builds a fresh Phi per call. If get_z_features fits its
# own PCA, tes_features_pca lives in a different PCA basis than orig_features_pca and
# should instead be projected with the train `components`/`feature_mapping` — confirm.
# NOTE(review): `tes_features_pca` looks like a typo for `test_features_pca`; the name is
# used again when saving the test features, so rename both places together.
test_features_full, tes_features_pca, _, _ = extract_features(testloader)

5it [04:18, 51.66s/it]


In [40]:
# classes_names = ""
# for i, y in enumerate(labels):
#     str_ = categories.loc[categories['y'] == i]['category'].tolist()[0]
#     classes_names += str_
#     if i < len(labels)-1:
#         classes_names += "_"
# classes_names

In [41]:
# Output directory: artifacts/extracted_features/<dataset>/<extractor>/<n>_class
suffix = "CLIP"
store_path = os.path.join("artifacts", "extracted_features", dataset_name, suffix, f"{len(labels)}_class")
# exist_ok=True replaces the isdir-then-makedirs check: no race, and idempotent on re-run.
os.makedirs(store_path, exist_ok=True)

In [42]:
import pickle

def store_mapping(mapping, directory=None):
    """Pickle ``mapping`` to ``pca_feature_mapping.pickle`` with the highest pickle protocol.

    Args:
        mapping: any picklable object (here the feature mapping returned for the train split).
        directory: target directory. It defaults to the notebook-level ``store_path``,
            resolved at call time, so existing calls behave exactly as before.
    """
    target_dir = store_path if directory is None else directory
    with open(os.path.join(target_dir, 'pca_feature_mapping.pickle'), 'wb') as handle:
        pickle.dump(mapping, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [43]:
# Persist the feature mapping returned when extracting the train-split features.
store_mapping(feature_mapping)

In [44]:
def store_features(features, mode, type_, directory=None):
    """Save ``features`` with ``np.save`` as ``<type_>_<mode>_<features.shape[1]>.npy``.

    Args:
        features: array-like with at least two dimensions; the size of its second
            dimension is embedded in the file name.
        mode: split tag such as "train" or "test".
        type_: feature-kind tag, e.g. "orig_full", "orig_pca", "metadata".
        directory: target directory. It defaults to the notebook-level ``store_path``,
            resolved at call time, so existing calls behave exactly as before.
    """
    target_dir = store_path if directory is None else directory
    np.save(os.path.join(target_dir, f"{type_}_{mode}_{features.shape[1]}.npy"), features)

In [45]:
# store_features(colored_features, "train", "color")
# store_features(rotated_features, "train","rotation")
# store_features(zoom_features, "train","zoom")
# store_features(shift_features, "train","shift")
# Save the unaugmented train-split features (full and PCA-reduced).
for train_array, feature_type in (
    (orig_features_full, "orig_full"),
    (orig_features_pca, "orig_pca"),
):
    store_features(train_array, "train", feature_type)

In [46]:
# store_features(mix_1_features, "train", "mix_1")
# store_features(mix_2_features, "train", "mix_2")

In [47]:
# store_features(colored_features_2, "train", "color_2")

In [48]:
# store_features(gaussian_features, "train", "gaussian")
# store_features(uniform_features, "train", "uniform")

In [49]:
# Save the test-split features (full and PCA-reduced).
for test_array, feature_type in (
    (test_features_full, "orig_full"),
    (tes_features_pca, "orig_pca"),
):
    store_features(test_array, "test", feature_type)

In [50]:
# Save the stacked test-split metadata; the file name embeds its column count (shape[1]).
store_features(test_metadata, "test", "metadata")

In [51]:
# Save the components returned when extracting the train-split features.
store_features(components, "train", "pca_components")

In [52]:
# Display the directory the artifacts above were written to.
store_path

'artifacts/extracted_features/waterbirds/CLIP/2_class'