In [1]:
from __future__ import print_function

import sys
import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import time
import numpy as np

import torchvision.datasets as datasets
import math
from easydict import EasyDict as edict
from torch.utils.tensorboard import SummaryWriter

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.sampler import SubsetRandomSampler


from PIL import Image
import matplotlib.pyplot as plt
import umap
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure

from scipy.sparse.linalg import cg
from scipy.sparse import csr_matrix, identity, diags
from scipy.stats import entropy

import faiss
from progressbar import ProgressBar
from helper import * 
import warnings
# Silence all warnings so library chatter does not clutter the cell outputs.
warnings.simplefilter('ignore')

# Checkpoint file that the model-building cell restores the network weights from.
model_file_loc = './checkpoint/cifar_fixed_weight_label_unlabel_separate_factor_schedule_all_data_ramp_up_1_lr_point_zero_five.t'

# Restrict the process to GPU 0; set before the availability check below is made.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Run on the GPU when one is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
Loading faiss with AVX2 support.


In [7]:
print('==> Building model..')
net = ResNet18(pool_len=4, low_dim=128, fixed_weight=True, temperature=.1)

if device == 'cuda':
    # Wrap the network across every visible GPU before the weights are loaded.
    # NOTE(review): presumably the checkpoint was saved from the wrapped model, so its
    # keys match only after wrapping; on a CPU-only run the keys may not line up — confirm.
    net = MyDataParallel(net, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

# Load the model
# map_location remaps the stored tensors onto the active device, so a checkpoint written
# on a GPU can still be deserialised when only the CPU is available.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
state = torch.load(model_file_loc, map_location=device)
net.load_state_dict(state['net'])
net.to(device)

==> Building model..


MyDataParallel(
  (module): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
     

In [48]:
# Generate the unlabelled / labelled index sets used to slice the extracted features.
unlabel_index, label_index = generate_subset_of_CIFAR_for_ssl(5000, 25, 1)

# Deterministic preprocessing (no augmentation): tensor conversion + channel normalisation.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

# Build the test split first and the train split second, both with the same transform.
testset_no_augment, trainset_no_augment = (
    torchvision.datasets.CIFAR10(
        root='./data', train=use_train_split, download=True, transform=transform_test
    )
    for use_train_split in (False, True)
)

# Train samples come first, test samples after them, in one un-shuffled loader.
combined_dataset_no_aug = torch.utils.data.ConcatDataset(
    [trainset_no_augment, testset_no_augment]
)
combined_dataloader_no_aug = torch.utils.data.DataLoader(
    combined_dataset_no_aug,
    batch_size=1024,
    shuffle=False,
    num_workers=20,
)

# Features and labels for every sample of the combined dataset.
feature_mat, label_arr = sem_sup_feature(net, combined_dataloader_no_aug)

# Labelled data matrix and its labels.
Mat_Label, labels = feature_mat[label_index], label_arr[label_index]

# Unlabelled data matrix; its labels are kept only to measure accuracy later.
Mat_Unlabel, rest_label = feature_mat[unlabel_index], label_arr[unlabel_index]

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [49]:
train_batchsize = 512
print('==> Preparing data..')

# Random augmentation pipeline applied to every image drawn from the loaders below.
transform_train = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

# Both CIFAR-10 splits get the augmenting transform; the test split is built first.
testset, trainset = (
    torchvision.datasets.CIFAR10(
        root='./data', train=use_train_split, download=True, transform=transform_train
    )
    for use_train_split in (False, True)
)

# Train samples first, then test samples.
combined_dataset = torch.utils.data.ConcatDataset([trainset, testset])

# One random sampler per index set, each restricted to its own subset of the combined data.
label_sampler = SubsetRandomSampler(label_index)
unlabel_sampler = SubsetRandomSampler(unlabel_index)

dataloader_label = torch.utils.data.DataLoader(
    combined_dataset,
    batch_size=train_batchsize,
    sampler=label_sampler,
    num_workers=5,
)
dataloader_unlabel = torch.utils.data.DataLoader(
    combined_dataset,
    batch_size=train_batchsize,
    sampler=unlabel_sampler,
    num_workers=20,
)
def mixup_aug(input_main, input_noise, beta_param_1=1, beta_param_2=2):
    """Blend each sample of ``input_main`` with the matching sample of ``input_noise``.

    A separate mixing weight is drawn per sample from ``Beta(beta_param_1, beta_param_2)``
    and folded into ``[0, 0.5]``, so ``input_main`` always keeps at least half of the mix.

    Args:
        input_main: tensor whose first dimension is the batch; dominant part of the mix.
        input_noise: tensor broadcast-compatible with ``input_main``; minor part of the mix.
        beta_param_1: first shape parameter of the Beta distribution.
        beta_param_2: second shape parameter of the Beta distribution.

    Returns:
        Tensor ``(1 - w) * input_main + w * input_noise`` with one weight ``w`` per sample.
    """
    batch_sz = input_main.size(0)
    beta = np.random.beta(beta_param_1, beta_param_2, size=batch_sz)
    beta = np.minimum(1. - beta, beta)  # don't want beta to be larger than 0.5

    # Create the weights directly on the input's device (no reliance on a global) and give
    # them one singleton axis per non-batch dimension so they broadcast for any input rank.
    weight_shape = (batch_sz,) + (1,) * (input_main.dim() - 1)
    weight = torch.tensor(beta, dtype=torch.float, device=input_main.device).view(weight_shape)

    return (1. - weight) * input_main + weight * input_noise
        
        
        
def sem_sup_mixup(net, label_dl, unlabel_dl):
    """Extract features of mixup-augmented labelled and unlabelled batches.

    One pass is made over ``unlabel_dl``. Each unlabelled batch is paired with the next
    batch of ``label_dl``, which is restarted whenever it is exhausted. Per iteration:

    * every labelled image is mixed with an image drawn without replacement from the
      labelled and unlabelled batches concatenated;
    * every unlabelled image is mixed with another image of the same unlabelled batch, and
      the original and mixed unlabelled images are both passed through ``net``.

    Args:
        net: model called as ``net(batch)``; the first element of its return value is
            kept as the feature.
        label_dl: iterable of ``(inputs, targets)`` batches for labelled data.
        unlabel_dl: iterable of ``(inputs, _)`` batches for unlabelled data.

    Returns:
        Tuple ``(label_features, targets, unlabel_features)`` of numpy arrays, each the
        concatenation of the per-iteration results along the first axis.
    """
    net.eval()
    label_dl_iter = iter(label_dl)

    unlabel_feature = []
    label_feature = []
    target_list = []
    for inputs, _ in unlabel_dl:
        inputs = inputs.to(device)
        try:
            label_input, target = next(label_dl_iter)
        except StopIteration:
            # The labelled loader ran out first: restart it. Only exhaustion is caught,
            # so real data-loading errors still surface.
            label_dl_iter = iter(label_dl)
            label_input, target = next(label_dl_iter)

        label_input = label_input.to(device)

        label_batch_sz = label_input.size(0)
        unlabel_batch_sz = inputs.size(0)

        # Mixing partners for the labelled images, sampled from both batches together.
        label_and_unlabel_inputs = torch.cat([label_input, inputs], dim=0)
        index = np.random.choice(unlabel_batch_sz + label_batch_sz, replace=False, size=label_batch_sz)
        inputs_shuffled = label_and_unlabel_inputs[index, :, :, :]
        inputs_mixup_label = mixup_aug(label_input, inputs_shuffled)

        # Mixing partners for the unlabelled images: a permutation of the same batch.
        idx = torch.randperm(inputs.size(0))
        inputs_mixup_unlabel = mixup_aug(inputs, inputs[idx])
        inputs_mixup_unlabel = torch.cat([inputs, inputs_mixup_unlabel], dim=0)

        with torch.no_grad():
            all_pred_unlabel, _ = net(inputs_mixup_unlabel)
            all_pred_label, _ = net(inputs_mixup_label)

        # Move results off the accelerator right away so they do not pile up in its memory.
        unlabel_feature.append(all_pred_unlabel.cpu())
        label_feature.append(all_pred_label.cpu())
        target_list.append(target.cpu())

    final_label_feature = torch.cat(label_feature, 0).numpy()
    final_target = torch.cat(target_list, 0).numpy()
    final_unlabel_feature = torch.cat(unlabel_feature, 0).numpy()

    return final_label_feature, final_target, final_unlabel_feature
        
    
        
# Features of the mixup-augmented labelled / unlabelled batches, plus the labelled targets.
(extra_label_feature,
 extra_target,
 extra_unlabel_feature) = sem_sup_mixup(net, dataloader_label, dataloader_unlabel)

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [50]:
# Rebuild the matrices from the extracted features rather than appending to their previous
# values, so re-running this cell does not keep growing them with duplicate augmented rows.
# Augmented labelled features are appended after the original labelled ones.
Mat_Label = np.concatenate((feature_mat[label_index], extra_label_feature))
labels = np.concatenate((label_arr[label_index], extra_target))
# Augmented unlabelled features go first, so the original unlabelled samples remain the
# last len(unlabel_index) rows.
Mat_Unlabel = np.concatenate((extra_unlabel_feature, feature_mat[unlabel_index]))

In [58]:
# Example of how to do label propagation
# Reference timestamp for the timings reported at the end of this cell.
start_time = time.time()
# Two types of distance function
def trns(x):
    """Return 0 for a negative ``x`` and ``x**5`` otherwise.

    Taken from this paper: https://arxiv.org/pdf/1904.04717.pdf
    """
    if x < 0:
        return 0
    return x**5

def trns1(x):
    """Return ``exp(-d**2 / 0.01)`` with ``d = (1 - x) / 2``.

    This is the more traditionally used of the two functions.
    """
    half_gap = (1 - x) / 2
    exponent = -(half_gap**2) / .01
    return math.exp(exponent)

# Stack the features with the labelled examples at the beginning.
MatX = np.vstack((Mat_Label, Mat_Unlabel))
# Divide every row by its norm (not really required since norm of feature is already 1).
MatX = MatX / np.linalg.norm(MatX, axis=-1, keepdims=True)

# Create the sparse affinity matrix and note when it was finished.
affinity_matrix = buildGraph(MatX, trns1, 100)
affinity_matrix_time = time.time()

# Do label propagation.
unlabel_data_labels, unlabel_class_prob = labelPropagation(
    affinity_matrix,
    Mat_Label,
    Mat_Unlabel,
    labels,
    alpha=.95,
    n_iter=300,
)

# Score only the trailing len(unlabel_index) predictions against the held-back labels.
unlabel_data_labels_rest = unlabel_data_labels[-len(unlabel_index):]
accuracy = get_acc(unlabel_data_labels_rest, rest_label)
time_taken = time.time() - start_time
print(f"Accuracy is {accuracy:.4f} and time for affinity matrix {affinity_matrix_time - start_time:.0f} seconds, label propagation time {time.time() - affinity_matrix_time:.0f} seconds")

Accuracy is 0.7729 and time for affinity matrix 23 seconds, label propagation time 122 seconds


In [53]:
# Re-run label propagation on the affinity matrix built by the label-propagation example cell.
# The run is timed with its own clock: measuring from `affinity_matrix_time` (set in another
# cell) would also count the idle time between the two cell executions.
lp_start_time = time.time()
unlabel_data_labels, unlabel_class_prob = labelPropagation(affinity_matrix, Mat_Label, Mat_Unlabel, labels, alpha=.95, n_iter=300) # Doing LP
lp_seconds = time.time() - lp_start_time

# Score only the trailing len(unlabel_index) predictions against the held-back labels.
unlabel_data_labels_rest = unlabel_data_labels[-len(unlabel_index):]
accuracy = get_acc(unlabel_data_labels_rest, rest_label) # Measuring accuracy
# Total compute time: the earlier affinity-matrix build plus this propagation run.
time_taken = (affinity_matrix_time - start_time) + lp_seconds
print(f"Accuracy is {accuracy:.4f} and time for affinity matrix {affinity_matrix_time - start_time:.0f} seconds, label propagation time {lp_seconds:.0f} seconds")

Accuracy is 0.7554 and time for affinity matrix 14 seconds, label propagation time 243 seconds
