In [None]:
import os

import clip
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset_places8_2classes import PlacesDataset

# Class index -> class name; the name is what gets inserted into the text prompt.
idx_to_classname = {0: 'bedroom', 1: 'childs_room'}
n_classes = len(idx_to_classname)
print(n_classes)
data_path = "../adversarial-sets/data/Places8_paths_and_labels_complete_train.npy"
# Extracted features are written to a dedicated sub-directory of ./outputs.
outpath = "./outputs"
outpath = os.path.join(outpath, "places8_bedroom_childs_room")
# The mode has to be an octal literal: decimal 777 equals 0o1411, which would
# create the directory without owner write permission. The process umask is
# still applied on top of this value.
os.makedirs(outpath, mode=0o777, exist_ok=True)

# Build one dataset per class index (restricted through `onlylabels`) and wrap
# each in its own non-shuffling DataLoader, keyed by that class index.
batch_size = 64
places_ds = []
for class_idx in range(n_classes):
    places_ds.append(PlacesDataset(data_path, onlylabels=[class_idx]))

train_dataloaders_class = {}
for class_idx, class_ds in enumerate(places_ds):
    train_dataloaders_class[class_idx] = DataLoader(
        class_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
    )

# Report the size of every per-class loader.
for class_idx in range(n_classes):
    loader = train_dataloaders_class[class_idx]
    print(f"\nDataloader: {batch_size} batch size | {len(loader)} batches | {len(loader.dataset)} images")

# List the model names this CLIP installation reports as available.
print(clip.available_models())

# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
# Load the "ViT-B/32" CLIP model on the selected device; clip.load returns the
# model together with a second object that is kept here as `preprocess`.
clip_bundle = clip.load("ViT-B/32", device=device)
model, preprocess = clip_bundle
print(model)

In [None]:
# Show the `preprocess` object that clip.load returned alongside the model.
print(preprocess)

In [None]:
# Encode every image of each class, plus a matching "a photo of a <class>" text
# prompt, with CLIP, then save the image and text features of each class to
# separate .npy files under `outpath`.
for i in range(n_classes):
    image_features = []
    text_features = []
    for inputs, labels, _ in tqdm(train_dataloaders_class[i]):
        inputs = inputs.to(device)
        # clip.tokenize accepts a list of strings and returns one token row per
        # prompt, so the whole batch is tokenized in a single call.
        prompts = [f"a photo of a {idx_to_classname[label.item()]}" for label in labels]
        text_inputs = clip.tokenize(prompts, truncate=True).to(device)

        # Calculate features; no autograd graph is needed for pure inference.
        with torch.no_grad():
            img_feat_batch = model.encode_image(inputs).cpu().numpy()
            text_feat_batch = model.encode_text(text_inputs).cpu().numpy()

        # Rows are accumulated as lists of Python floats, so np.asarray below
        # produces float64 arrays whatever precision the model computed in.
        image_features.extend(img_feat_batch.tolist())
        text_features.extend(text_feat_batch.tolist())

    # From here on the names refer to arrays with one row per image of class i.
    image_features = np.asarray(image_features)
    text_features = np.asarray(text_features)
    np.save(f"{outpath}/places8_image_features_clip_class_{i}.npy", image_features)
    np.save(f"{outpath}/places8_text_features_clip_class_{i}.npy", text_features)


In [None]:
# `image_features` is rebuilt for every class inside the extraction loop, so at
# this point it only holds the array of the last class processed; this prints
# the number of feature rows for that class, not for the whole dataset.
print(len(image_features))