In [1]:
import numpy as np
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
import pandas as pd

random_state = 0
np.random.seed(random_state)

In [2]:
class Strategy:
    def __init__(self, dataset):
        self.dataset = dataset

    def query(self, n):
        pass

    def update(self, pos_idxs, neg_idxs=None):
        self.dataset.labeled_idxs[pos_idxs] = True
        if neg_idxs:
            self.dataset.labeled_idxs[neg_idxs] = False
    def train(self):
        labeled_idxs, labeled_data = self.dataset.get_labeled_data()
        print(f'Training on {labeled_idxs}')

    def predict(self, data):
        preds = np.ones(data.shape[0])
        return preds

class RandomSampling(Strategy):
    def __init__(self, dataset):
        super(RandomSampling, self).__init__(dataset)

    def query(self, n):
        return np.random.choice(np.where(self.dataset.labeled_idxs==0)[0], n, replace=False)

# dataset preparation
class Data:
    def __init__(self, data: pd.DataFrame, 
                 features: list, target: str, random_state: int, test_size: float, handler):
        # extract data
        X = data[features].values
        y = data[target].values
        # scale the X data
        print('Scaling the data')
        X = (X - X.mean(axis=0)) / X.std(axis=0)
        # split data
        print('Splitting the data')
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
        # handler
        self.handler = handler
        self.n_pool = len(self.X_train)
        self.n_test = len(self.X_test)
        self.labeled_idxs = np.zeros(self.n_pool, dtype=bool)
        
    def initialize_labels(self, num):
        '''initialize the labels of the pool data'''
        tmp_idxs = np.arange(self.n_pool)
        np.random.shuffle(tmp_idxs)
        self.labeled_idxs[tmp_idxs[:num]] = True
    
    def get_labeled_data(self):
        '''return the labeled data'''
        labeled_idxs = np.arange(self.n_pool)[self.labeled_idxs]
        return labeled_idxs, self.handler(self.X_train[labeled_idxs], self.y_train[labeled_idxs])
    
    def get_unlabeled_data(self):
        '''return the unlabeled data'''
        unlabeled_idxs = np.arange(self.n_pool)[~self.labeled_idxs]
        return unlabeled_idxs, self.handler(self.X_train[unlabeled_idxs], self.y_train[unlabeled_idxs])
    
    def get_train_data(self):
        '''return the train data'''
        return self.labeled_idxs.copy(), self.handler(self.X_train, self.y_train)
        
    def get_test_data(self):
        '''return the test data'''
        return self.handler(self.X_test, self.y_test)

# dataset handler
class Handler(Dataset):
    '''dataset handler to handle access to the data'''
    def __init__(self, X, Y):
        self.X = X
        self.Y = Y

    def __getitem__(self, index):
        x, y = self.X[index], self.Y[index]
        return x, y, index

    def __len__(self):
        return len(self.X)
    

In [3]:
data = pd.DataFrame(np.random.randn(50, 5), columns=['a', 'b', 'c', 'd', 'e'])
features = ['a', 'b', 'c', 'd']
target = 'e'
test_size = 0.2
dataset = Data(data, features, target, random_state, test_size, Handler)
print('n_pool/n_train:', dataset.n_pool)
print('n_test:', dataset.n_test)
print('labeled_idxs:', dataset.labeled_idxs)
print('get_unlabeled_data:', dataset.get_unlabeled_data()[0])
print('get_labeled_data:', dataset.get_labeled_data()[0])
print('get_train_data:', dataset.get_train_data()[0])
print('len of get_train_data:', len(dataset.get_train_data()[0]))

Scaling the data
Splitting the data
n_pool/n_train: 40
n_test: 10
labeled_idxs: [False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False]
get_unlabeled_data: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39]
get_labeled_data: []
get_train_data: [False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False]
len of get_train_data: 40


In [4]:
strategy = RandomSampling(dataset)

In [5]:
# Experiment not yet started
print('**************** Experiment not yet started ****************')
print(f"number of testing pool: {dataset.n_test}")
print(f"number of training pool: {dataset.n_pool}")
print(f'labeled_idxs: {dataset.labeled_idxs}')
print(f'get_labeled_data: {dataset.get_train_data()[0]}')
print(f'pos_idxs: {np.where(dataset.get_train_data()[0] == True)[0]}')

**************** Experiment not yet started ****************
number of testing pool: 10
number of training pool: 40
labeled_idxs: [False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False]
get_labeled_data: [False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False]
pos_idxs: []


In [6]:
n_round = 10
n_query = 3
# start experiment rd = 0
print('**************** Start experiment ****************')
print('++++ Round 0 ++++')
n_init_labeled = 4
dataset.initialize_labels(n_init_labeled) # reshuffle the data, get the labeled_idxs[train_data_idx[:n_init_labeled]]
print(f'labeled_idxs: {dataset.labeled_idxs}')
print(f'get_train_data: {dataset.get_train_data()[0]}')
print(f'pos_idxs: {np.where(dataset.get_train_data()[0] == True)[0]}')
print(f'get_labeled_data: {dataset.get_labeled_data()[0]}')
print(f'get_unlabeled_data: {dataset.get_unlabeled_data()[0]}')

**************** Start experiment ****************
++++ Round 0 ++++
labeled_idxs: [False False False False False False False False False False False False
 False False False  True  True False False False False False False False
  True False False False False False  True False False False False False
 False False False False]
get_train_data: [False False False False False False False False False False False False
 False False False  True  True False False False False False False False
  True False False False False False  True False False False False False
 False False False False]
pos_idxs: [15 16 24 30]
get_labeled_data: [15 16 24 30]
get_unlabeled_data: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 17 18 19 20 21 22 23 25 26
 27 28 29 31 32 33 34 35 36 37 38 39]


In [7]:
for rd in range(1, n_round+1):
    print(f"++++ Round {rd} ++++")
    # get index for new query
    query_idxs = strategy.query(n_query)
    print(f"query_idxs: {query_idxs}")
    print(f'get_train_data: {dataset.get_train_data()[0]}')
    #  change bool dataset.get_train_data()[0] to pos_idxs\
    print(f'pos_idxs: {np.where(dataset.get_train_data()[0] == True)[0]}')
    print(f'get_labeled_data: {dataset.get_labeled_data()[0]}')
    print(f'get_unlabeled_data: {dataset.get_unlabeled_data()[0]}')
    # update labels
    strategy.update(query_idxs)
    print(f'labeled_idxs: {dataset.labeled_idxs}')
    # retrain model
    print('                                                ')

++++ Round 1 ++++
query_idxs: [ 1  5 29]
get_train_data: [False False False False False False False False False False False False
 False False False  True  True False False False False False False False
  True False False False False False  True False False False False False
 False False False False]
pos_idxs: [15 16 24 30]
get_labeled_data: [15 16 24 30]
get_unlabeled_data: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 17 18 19 20 21 22 23 25 26
 27 28 29 31 32 33 34 35 36 37 38 39]
labeled_idxs: [False  True False False False  True False False False False False False
 False False False  True  True False False False False False False False
  True False False False False  True  True False False False False False
 False False False False]
                                                
++++ Round 2 ++++
query_idxs: [ 4 23 39]
get_train_data: [False  True False False False  True False False False False False False
 False False False  True  True False False False False False False False
