In [26]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from rac.pred_models import CustomTensorDataset, ACCNet
import numpy as np


In [2]:
# Training-time preprocessing: random crop / horizontal flip augmentation,
# conversion to a tensor, then per-channel normalization with the given
# mean / std constants.
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Evaluation-time preprocessing: same tensor conversion and normalization but
# WITHOUT the random augmentation, so that test-set results are deterministic
# and measured on unmodified images.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

batch_size = 256

trainset = torchvision.datasets.CIFAR10(root='../datasets/cifar10_original_data', train=True,
                                        download=True, transform=transform)
# Raw images and labels read straight from the dataset attributes; they are
# not passed through `transform`.
X_train = trainset.data
y_train = trainset.targets
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='../datasets/cifar10_original_data', train=False,
                                       download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

# Raw test images and labels (again not passed through any transform).
X_test = testset.data
y_test = testset.targets

# Human-readable class names, indexed by integer label.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    

Files already downloaded and verified
Files already downloaded and verified


In [6]:
def clustering_from_clustering_solution(clustering_solution):
    """Group object indices by their cluster label.

    Parameters
    ----------
    clustering_solution : sequence of int
        ``clustering_solution[i]`` is the (non-negative, zero-based) cluster
        label of object ``i``.

    Returns
    -------
    clustering : list of list of int
        ``clustering[k]`` holds the indices of all objects with label ``k``,
        in increasing order. Labels that never occur yield empty lists.
    num_clusters : int
        ``max(label) + 1``, or ``0`` when the input is empty.
    """
    # np.max raises on an empty sequence; an empty solution simply has no clusters.
    if len(clustering_solution) == 0:
        return [], 0

    num_clusters = int(np.max(clustering_solution)) + 1
    clustering = [[] for _ in range(num_clusters)]
    for index, label in enumerate(clustering_solution):
        clustering[label].append(index)
    return clustering, num_clusters

def sim_matrix_from_clustering(clustering, N):
    """Build an ``N x N`` similarity matrix from a list of clusters.

    Every entry starts at ``-1``; for each cluster (a list of object indices)
    all entries whose row and column both belong to that cluster are set to
    ``+1``.
    """
    similarity = np.full((N, N), -1.0)
    for members in clustering:
        # Open mesh selecting the square sub-block spanned by this cluster.
        block = np.ix_(members, members)
        similarity[block] = 1
    return similarity

In [7]:
# Build the ground-truth clustering and the corresponding similarity matrix
# from the first 2000 labels of each split.
train_labels_head = y_train[:2000]
train_sol = clustering_from_clustering_solution(train_labels_head)
train_sim_matrix = sim_matrix_from_clustering(train_sol[0], len(train_labels_head))

test_labels_head = y_test[:2000]
test_sol = clustering_from_clustering_solution(test_labels_head)
test_sim_matrix = sim_matrix_from_clustering(test_sol[0], len(test_labels_head))

In [8]:
# Sanity check: display the dimensions of the training similarity matrix.
train_sim_matrix.shape

(2000, 2000)

In [16]:
def get_pairs(prop_pos, prop_neg, sim_matrix, data, rng=None):
    """Sample object pairs together with binary "same cluster" labels.

    Candidate pairs are the strictly-lower-triangular entries of
    ``sim_matrix`` (each unordered pair ``i > j`` exactly once). Entries equal
    to ``1`` are positive candidates, entries equal to ``-1`` are negative
    candidates.

    Parameters
    ----------
    prop_pos, prop_neg : float
        Fraction of the positive / negative candidates to draw. The number of
        drawn pairs is ``int(len(candidates) * proportion)``.
    sim_matrix : np.ndarray of shape (N, N)
        Pairwise similarity matrix. It is not modified.
    data : np.ndarray
        Array indexed along its first axis by the object indices of
        ``sim_matrix``.
    rng : optional
        Object exposing ``choice(a, size, replace=...)`` (for example a
        ``np.random.Generator``). Defaults to the global ``np.random`` state,
        so ``np.random.seed`` keeps working.

    Returns
    -------
    x1, x2 : np.ndarray
        ``data`` rows of the first and second element of every sampled pair;
        positive pairs come first, then negative pairs.
    y : np.ndarray
        ``1`` where the pair's similarity is ``>= 0`` and ``0`` otherwise,
        in the dtype of ``sim_matrix``.
    """
    N = sim_matrix.shape[0]
    rows, cols = np.tril_indices(N, -1)
    lower_values = sim_matrix[rows, cols]
    sampler = np.random if rng is None else rng

    def _draw(pool, proportion):
        # Sample without replacement so the same pair is never returned twice.
        # Only when more pairs are requested than exist (proportion > 1) do we
        # fall back to sampling with replacement, as the original code did.
        count = int(len(pool) * proportion)
        return sampler.choice(pool, count, replace=count > len(pool))

    ind_pos = _draw(np.flatnonzero(lower_values == 1), prop_pos)
    ind_neg = _draw(np.flatnonzero(lower_values == -1), prop_neg)
    indices = np.concatenate([ind_pos, ind_neg])

    ind1, ind2 = rows[indices], cols[indices]
    x1 = data[ind1]
    x2 = data[ind2]

    # Fancy indexing yields a copy, so the caller's matrix stays untouched.
    pair_sims = sim_matrix[ind1, ind2]
    y = (pair_sims >= 0).astype(pair_sims.dtype)
    return x1, x2, y

In [22]:
# Sample 0.1% of the positive and 0.1% of the negative candidate pairs per split.
x1_train, x2_train, y_train_pairs = get_pairs(0.001, 0.001, train_sim_matrix, X_train)
# Test pairs must be drawn from the test similarity matrix and the test images;
# sampling them from the training data would leak training examples into evaluation.
x1_test, x2_test, y_test_pairs = get_pairs(0.001, 0.001, test_sim_matrix, X_test)

199 1799
199 1799


In [23]:
# Report the shapes of the sampled pair arrays, training split first.
for first_items, second_items, pair_labels in (
    (x1_train, x2_train, y_train_pairs),
    (x1_test, x2_test, y_test_pairs),
):
    print(first_items.shape, second_items.shape, pair_labels.shape)

(1998, 32, 32, 3) (1998, 32, 32, 3) (1998,)
(1998, 32, 32, 3) (1998, 32, 32, 3) (1998,)


In [25]:
# Normalization step shared by the training and test pipelines.
cifar_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training pipeline: convert to a PIL image, apply random crop / horizontal
# flip augmentation, convert to a tensor and normalize.
cifar_training_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    cifar_normalize,
])

# Test pipeline: same conversion and normalization, no random augmentation.
cifar_test_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    cifar_normalize,
])

# Pair datasets: each one is built from the two image arrays and the labels.
train_dataset = CustomTensorDataset(
    x1_train, x2_train, y_train_pairs, transform=cifar_training_transform
)
test_dataset = CustomTensorDataset(
    x1_test, x2_test, y_test_pairs, transform=cifar_test_transform
)

In [27]:
# get_pairs returns all positive pairs followed by all negative pairs, so the
# training data must be shuffled; otherwise every batch contains a single class.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
# Binary cross-entropy on raw (pre-sigmoid) model outputs.
criterion = nn.BCEWithLogitsLoss()