In [None]:
import numpy as np
import gzip
import shutil
import random
import os
import numpy as np
import vowpalwabbit as vw
import gc
import re
import argparse
from tqdm import tqdm

In [None]:
def read_data(filename):
    """Load a text dataset and return its lines with trailing whitespace removed.

    Args:
        filename: Path to the data file. A path ending in '.vw.gz' is first
            expanded on disk through ``uncompress`` and the expanded file is
            the one that gets read.

    Returns:
        list[str]: One entry per line of the file, right-stripped.
    """
    path = uncompress(filename) if filename.endswith('.vw.gz') else filename
    with open(path, 'r') as handle:
        # The message is printed only once the file has been opened successfully.
        print(f'opening {path}')
        return [row.rstrip() for row in handle]

def uncompress(filename):
    """Expand a gzip file next to itself and return the path of the result.

    The output path is the input path with its trailing '.gz' removed
    (e.g. 'dir/ds_1.vw.gz' -> 'dir/ds_1.vw'). Only the final extension is
    inspected, so dots elsewhere in the path (a leading './', dotted
    directory names, dotted file stems) do not affect the result.

    Args:
        filename: Path to a file whose name ends in '.gz' (case-insensitive).

    Returns:
        str: Path of the decompressed file that was written.

    Raises:
        ValueError: If ``filename`` does not end in '.gz'. Raised before any
            file is opened so that a mis-named input is never overwritten.
    """
    target, extension = os.path.splitext(filename)
    if extension.lower() != '.gz':
        raise ValueError(f'expected a .gz file, got {filename!r}')
    with gzip.open(filename, 'rb') as f_in:
        with open(target, 'wb') as f_out:
            # Stream the payload so large datasets are not held in memory.
            shutil.copyfileobj(f_in, f_out)
    return target

def process_actions(sup_data):
    """Summarise the label distribution of a labelled dataset.

    The text before the first '|' of each line is used verbatim as that
    line's action label (surrounding whitespace is kept).

    Args:
        sup_data: Non-empty iterable of lines shaped like '<label>|<features>'.

    Returns:
        tuple: ``(action_size, action_list, act_cnt, action_maj,
        best_const_act_acc)`` where

        * ``action_size`` is the number of distinct labels,
        * ``action_list`` holds the distinct labels in order of first
          appearance (deterministic for a given input),
        * ``act_cnt[i]`` is the number of lines labelled ``action_list[i]``,
        * ``action_maj`` is a float numpy vector that is one-hot on the most
          frequent label when that label covers more than half of the lines,
          and all zeros otherwise,
        * ``best_const_act_acc`` is the fraction of lines carrying the most
          frequent label, i.e. the accuracy of always predicting it.

    Raises:
        ValueError: If ``sup_data`` contains no lines.
    """
    # Single pass; dict insertion order gives a stable label order.
    tally = {}
    for data in sup_data:
        action = data.split('|')[0]
        tally[action] = tally.get(action, 0) + 1
    if not tally:
        raise ValueError('process_actions() requires at least one data line')

    action_list = list(tally)
    act_cnt = list(tally.values())
    action_size = len(action_list)
    total = sum(act_cnt)
    print('actions: ', action_list)

    # Fraction (not product) of samples covered by the most frequent label.
    best_const_act_acc = max(act_cnt) / float(total)
    print('action count:', act_cnt)

    action_maj = np.zeros(action_size)
    if best_const_act_acc > 0.5:
        action_maj[int(np.argmax(act_cnt))] = 1

    return action_size, action_list, act_cnt, action_maj, best_const_act_acc

def create_train_test_split(data):
    """Shuffle ``data`` and cut it into a 90% / 10% train / test pair.

    Note that the shuffle happens in place, so the caller's list is reordered
    as a side effect. The shuffle uses the global ``random`` module state.

    Args:
        data: Mutable sequence of samples.

    Returns:
        tuple: ``(train_data, test_data)`` -- the first ``int(0.9 * len(data))``
        shuffled items and the remaining ones.
    """
    random.shuffle(data)
    cut = int(0.9 * len(data))
    return data[:cut], data[cut:]

In [None]:

# from utils import *

class IGL_xCI():
    """Per-action learner built from pairs of Vowpal Wabbit regressors.

    For every action the object keeps two VW workspaces: ``dummi_f[a]``,
    which scores a context ``x``, and ``dummi_psi[a]``, which scores a
    feedback string ``y``. During the first half of the epochs both are
    trained with a contrastive objective; afterwards only ``dummi_f`` keeps
    learning, using ``dummi_psi``'s output as the regression target.

    Attributes set by ``__init__``:
        epochs, verbose, data_dir: constructor settings.
        data, train_data, test_data: raw lines and their 90/10 split.
        action_size, action_list, action_count, action_maj, const_act_acc:
            values returned by ``process_actions`` for the dataset.
        unbalanced: 1 when ``action_maj`` marks a majority action, else 0.
        action_flip_list: per-action flag; when set, ``dummi_f``'s score for
            that action is negated at prediction time.
        num_samples: number of lines in the dataset.
        dummi_f, dummi_psi: lists of VW workspaces, one entry per action.
    """

    def __init__(self, did=554, epochs=10,verbose=False,amlt=True) -> None:
        """Locate and load dataset ``did`` and create the VW workspaces.

        Args:
            did: Dataset id; the first file in the data directory whose name
                starts with ``'ds_<did>_'`` is loaded.
            epochs: Number of passes ``train`` makes over the data.
            verbose: When true, ``train`` prints per-epoch accuracy.
            amlt: When true read from '/mnt/data/openml/', else from './data/'.

        Raises:
            SystemExit: With code 1 if the dataset cannot be found, read or
                parsed. The underlying error is printed and chained.
        """
        self.epochs = epochs
        self.data_dir = '/mnt/data/openml/' if amlt else './data/'
        self.verbose = verbose
        self.unbalanced = 0

        vw_file_prefix = 'ds_' + str(did) + '_'
        # First match in directory-listing order, or None when nothing matches.
        vw_file_name = next(
            (f for f in os.listdir(self.data_dir) if f.startswith(vw_file_prefix)),
            None,
        )

        try:
            if vw_file_name is None:
                raise FileNotFoundError(
                    f'no file starting with {vw_file_prefix!r} in {self.data_dir!r}'
                )
            self.data = read_data(self.data_dir + vw_file_name)
            (self.action_size, self.action_list, self.action_count,
             self.action_maj, self.const_act_acc) = process_actions(self.data)
            if sum(self.action_maj) > 0:
                self.unbalanced = 1
            self.train_data, self.test_data = create_train_test_split(self.data)
            self.action_flip_list = np.zeros(self.action_size)
            self.num_samples = len(self.data)

            print(f'total epochs: {self.epochs}\n')
        except Exception as err:
            # Report the real cause; `exit` is an interactive helper and is not
            # guaranteed to stop execution, so raise SystemExit directly.
            print('Error in reading dataset!')
            print(f'  cause: {err!r}')
            raise SystemExit(1) from err

        self.dummi_f = []
        self.dummi_psi = []

        # One (f, psi) pair per action, both clipped to predictions in [-1, 1].
        for _ in range(self.action_size):
            vw_f = vw.Workspace('--max_prediction 1 --min_prediction -1',
                                quiet=False, enable_logging=True)
            self.dummi_f.append(vw_f)

            vw_psi = vw.Workspace('--min_prediction -1 --max_prediction 1', quiet=True)
            self.dummi_psi.append(vw_psi)

    def collect_data(self, sup_data):
        """Turn labelled lines into per-action ``[x, y]`` interaction samples.

        For each line a uniformly random action index is drawn with
        ``np.random.randint``. The feedback string ``y`` always contains the
        token ``1 + a_idx`` and additionally the token ``1 + action_size``
        when the drawn action equals the line's label.

        Args:
            sup_data: Lines shaped like '<int label>|<features>'.

        Returns:
            list[list[list[str]]]: ``data[a_idx]`` is the list of ``[x, y]``
            pairs collected for action ``a_idx``.
        """
        data = [[] for _ in range(self.action_size)]
        for line in sup_data:
            parts = line.split('|')
            x = parts[1]
            x_label = parts[0]
            a_idx = np.random.randint(self.action_size)
            rwd = int(self.action_list[a_idx]) == int(x_label)
            y = str(1 + a_idx)
            if rwd:
                y += ' ' + str(1 + self.action_size)
            data[a_idx].append([x, y])
        return data

    @staticmethod
    def _ranking_example(score, reference, features):
        """Build the VW example for the second term of the contrastive loss.

        The label is -1 when ``score < reference`` and 1 when
        ``score > reference``, with importance weight ``|score - reference|``.
        On a tie a random sign is used and no importance weight is written.
        """
        if score < reference:
            head = str(-1) + ' ' + str(np.fabs(score - reference))
        elif score > reference:
            head = str(1) + ' ' + str(np.fabs(score - reference))
        else:
            head = str(np.random.choice([1, -1]))
        return head + ' | ' + features

    def train(self, igl_data, sup_data):
        """Train every action's ``(f, psi)`` pair for ``self.epochs`` epochs.

        Args:
            igl_data: Output of ``collect_data``; ``igl_data[a]`` is the list
                of ``[x, y]`` samples for action ``a``.
            sup_data: Labelled lines used only for the per-epoch accuracy
                printout when ``self.verbose`` is true.
        """
        for epoch in range(self.epochs):
            if self.verbose:
                print('\nepoch: ', epoch)
            # First half: learn the reward decoder psi with the contrastive loss.
            # Second half: keep fitting f to psi's decoded reward only.
            contrastive_phase = epoch < self.epochs / 2

            for a_idx in range(self.action_size):
                a_data = igl_data[a_idx]
                dummi_f_a = self.dummi_f[a_idx]
                dummi_psi_a = self.dummi_psi[a_idx]

                for x, y in a_data:
                    # Independent partner sample drawn from the same action.
                    x_ind, y_ind = random.choice(a_data)
                    f_x = dummi_f_a.predict(' | ' + x)
                    psi_y = dummi_psi_a.predict(' | ' + y)
                    f_x_ind = dummi_f_a.predict(' | ' + x_ind)
                    psi_y_ind = dummi_psi_a.predict(' | ' + y_ind)

                    # First term: regress psi(y) onto f(x) and f(x) onto psi(y).
                    psi_eg_1 = str(f_x) + ' | ' + y
                    f_eg_1 = str(psi_y) + ' | ' + x

                    # Second term: label y_ind with 1 if psi(y_ind) > f(x) ...
                    psi_eg_2 = self._ranking_example(psi_y_ind, f_x, y_ind)
                    # ... and label x_ind with 1 if f(x_ind) > psi(y).
                    f_eg_2 = self._ranking_example(f_x_ind, psi_y, x_ind)

                    dummi_f_a.learn(f_eg_1)
                    if contrastive_phase:
                        dummi_f_a.learn(f_eg_2)

                        dummi_psi_a.learn(psi_eg_1)
                        dummi_psi_a.learn(psi_eg_2)

                # Decide whether this action's scores must be negated. Skipped
                # when the action received no samples (nothing to average).
                if contrastive_phase and a_data:
                    action_rwd = sum(dummi_psi_a.predict(' | ' + y) for _, y in a_data)
                    mean_rwd = action_rwd / len(a_data)
                    if self.action_maj[a_idx]:
                        if mean_rwd < 0:
                            self.action_flip_list[a_idx] = 1
                    else:
                        if mean_rwd > 0:
                            self.action_flip_list[a_idx] = 1

            if self.verbose:
                accuracy = self.evaluate(sup_data)
                print('f',accuracy)

    def f_prediction(self, x):
        """Return the action label whose (possibly negated) f-score is highest.

        Args:
            x: Feature text without the leading '|' separator; ' | ' is
                prepended here before calling VW.
        """
        a_score = []
        for a_idx in range(self.action_size):
            score = self.dummi_f[a_idx].predict(' | ' + x)
            a_score.append(0 - score if self.action_flip_list[a_idx] else score)
        return self.action_list[np.argmax(a_score)]

    def evaluate(self, sup_data):
        """Compute the accuracy of ``f_prediction`` on labelled lines.

        Prints how often each action was chosen.

        Args:
            sup_data: Lines shaped like '<int label>|<features>'.

        Returns:
            float: Fraction of lines whose predicted label equals the true
            label; ``0.0`` when ``sup_data`` is empty.
        """
        total_num = 0
        total_rwd = 0
        action_picked_dict = {}

        for line in sup_data:
            parts = line.split('|')
            # Pass raw features: f_prediction adds the ' | ' separator itself,
            # so examples are formatted exactly like the training examples.
            x = parts[1]
            x_label = parts[0]
            predict_label = self.f_prediction(x)
            total_num += 1
            total_rwd += int(int(predict_label) == int(x_label))
            action_picked_dict[predict_label] = action_picked_dict.get(predict_label, 0) + 1

        print('chosen actions: ', action_picked_dict)
        if total_num == 0:
            return 0.0
        return total_rwd / total_num

    def run(self):
        """Collect interaction data, train, and return test-set accuracy."""
        igl_data = self.collect_data(self.train_data)
        self.train(igl_data, self.test_data)
        accuracy = self.evaluate(self.test_data)
        return accuracy

In [None]:
def main():
    """Run one experiment: report accuracy before and after training.

    Builds an ``IGL_xCI`` agent for dataset id 116 from the local data
    directory, collects interaction data from the training split, prints the
    test accuracy of the untrained model, trains, and prints it again.
    """
    agent = IGL_xCI(did=116, verbose=True, amlt=False, epochs=10)
    # Other dataset ids left disabled here:
    # agent = IGL_xCI(did=1568,verbose=True,amlt=False,epochs=10)
    # agent = IGL_xCI(did=554,verbose=True,amlt=False,epochs=10) # MNIST

    igl_data = agent.collect_data(agent.train_data)

    # Baseline: accuracy before any learning has happened.
    untrained_acc = agent.evaluate(agent.test_data)
    print('f: ', untrained_acc)

    agent.train(igl_data, agent.test_data)

    trained_acc = agent.evaluate(agent.test_data)
    print('f: ', trained_acc)

In [None]:
# Notebook entry point: run the experiment defined in main().
main()