In [3]:
import socket

# Show which machine this kernel is running on.
print(f"Hostname: {socket.gethostname()}")

Hostname: sx-el-121920


In [4]:
import numpy as np
import torch

%load_ext autoreload
%autoreload 2

print("Torch version:", torch.__version__)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Torch version: 1.13.0a0+d0d6b1f


## Load datasets

In [5]:
from utils.mnist_preprocessing import *
from utils.mnist_plotting import *

# Dataset / DataLoader configuration shared by every split in this notebook.
# Number of samples per batch handed to each DataLoader.
DATASET_BATCH_SIZE: int = 128
# Shuffle flag forwarded to each DataLoader.
DATASET_SHUFFLE: bool = True

In [6]:
from torchvision import transforms

def _make_mnist_split(env):
    """Build one DatasetMNIST split.

    Every split in this notebook uses the same constructor arguments; only
    ``env`` differs ('train', 'val', 'test', 'test_fool').

    NOTE(review): ``filter=[5, 8]`` together with ``opt_postfix="2classes"``
    presumably restricts the data to the digits 5 and 8 -- confirm in
    utils.mnist_preprocessing.
    """
    return DatasetMNIST(root='./data',
                        env=env,
                        color=True,
                        opt_postfix="2classes",
                        filter=[5, 8],
                        first_color_max_nr=5,
                        transform=transforms.Compose([transforms.ToTensor()]))


def _make_loader(dataset, num_workers=10):
    """Wrap ``dataset`` in a DataLoader using the notebook-wide batch size and shuffle flag."""
    return torch.utils.data.DataLoader(dataset=dataset,
                                       batch_size=DATASET_BATCH_SIZE,
                                       shuffle=DATASET_SHUFFLE,
                                       num_workers=num_workers)


# initialize datasets (same construction order as before: train, val, test, test_fool)
train_set = _make_mnist_split('train')
val_set = _make_mnist_split('val')
test_set = _make_mnist_split('test')
test_set_fool = _make_mnist_split('test_fool')

# create dataloaders
train_loader = _make_loader(train_set)
val_loader = _make_loader(val_set)
test_loader = _make_loader(test_set)
test_fool_loader = _make_loader(test_set_fool)


MNIST dataset already exists
MNIST dataset already exists
MNIST dataset already exists
MNIST dataset already exists


In [9]:
# Report the length of each split's `data_label_tuples`, one line per split.
for split_label, split_loader in (
    ("training", train_loader),
    ("validation", val_loader),
    ("test", test_loader),
    ("test fool", test_fool_loader),
):
    print(f"Number of {split_label} samples: {len(split_loader.dataset.data_label_tuples)}")

Number of training samples: 9425
Number of validation samples: 1888
Number of test samples: 1866
Number of test fool samples: 1866


## Set device (for number crunching)

In [7]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

## Initialize CLIP

In [15]:
import clip

# Load the ResNet-50 CLIP variant directly onto `device` (chosen in the
# "Set device" cell) instead of calling .cuda() unconditionally, which raises
# on machines without CUDA even though `device` already falls back to the CPU.
# NOTE(review): the helpers in utils.clip_utils may still assume CUDA
# internally -- confirm before relying on a CPU-only run.
model, preprocess = clip.load("RN50", device=device)
model.eval()  # inference only: make sure the model is in evaluation mode

# Prompt template; "{}" is the placeholder for a class name.
mnist_template = 'a photo of the number: "{}".'
# The two class names used for the binary task.
mnist_classes = ["5", "8"]

In [18]:
from utils.clip_utils import *

# Arguments that are identical for every split; only the loader and the
# description string differ between the three calls.
clip_shared_args = (clip, model, preprocess, mnist_classes, mnist_template)

train_similarity, train_high_low = clip_mnist_similarity(*clip_shared_args, train_loader, "Training color")
val_similarity, val_high_low = clip_mnist_similarity(*clip_shared_args, val_loader, "Validation color")
test_similarity, test_high_low = clip_mnist_similarity(*clip_shared_args, test_loader, "Test color")

In [19]:
# Print the binary accuracy for each split, in train / validation / test order.
for split_label, split_similarity, split_high_low in (
    ("train", train_similarity, train_high_low),
    ("validation", val_similarity, val_high_low),
    ("test", test_similarity, test_high_low),
):
    print(f"Accuracy {split_label} binary MNIST color 5/8: {clip_mnist_binary_accuracy(split_similarity, split_high_low)}%")

Accuracy train binary MNIST color 5/8: 93.83%
Accuracy validation binary MNIST color 5/8: 92.71%
Accuracy test binary MNIST color 5/8: 95.95%
