In [None]:
import argparse
import label_data, encode_datasets, train_downstream_model
import torch
import pickle
import numpy as np
import os
from os.path import join, exists
from datetime import datetime
import sys
sys.path.append('../keyclass/')
import utils
import models
import create_lfs
import train_classifier

# Emulate the script's command-line interface inside the notebook.
parser_cmd = argparse.ArgumentParser()
parser_cmd.add_argument(
    '--config',
    default='../config_files/config_imdb.yml',
    help='Configuration file',
)
parser_cmd.add_argument(
    '--random_seed',
    type=int,
    default=0,
    help="Random Seed",
)

# An explicit argument list is supplied, so argparse does not fall back to
# sys.argv (which, in a notebook, belongs to the kernel launcher).
args_cmd = parser_cmd.parse_args(args=[
    '--config=../config_files/config_imdb.yml',
    '--random_seed=0',
])

## Training

### Encoding Dataset

In [None]:
# print("Encoding Dataset")
# encode_datasets.run(args)

In [None]:
# Parse the configuration file named by the (emulated) command-line arguments.
args = utils.Parser(config_file_path=args_cmd.config).parse()

# The encoder runs on the GPU whenever one is visible to torch, otherwise on the CPU.
encoder_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The config flag selects between the custom encoder wrapper and the default one;
# both are built from the same `base_encoder` identifier.
if args['use_custom_encoder']:
    model = models.CustomEncoder(
        pretrained_model_name_or_path=args['base_encoder'],
        device=encoder_device,
    )
else:
    model = models.Encoder(
        model_name=args['base_encoder'],
        device=encoder_device,
    )

# Embed each split and pickle the result as
# <data_path>/<dataset>/<split>_embeddings.pkl.
for split in ['train', 'test']:
    sentences = utils.fetch_data(dataset=args['dataset'], split=split, path=args['data_path'])
    embeddings = model.encode(
        sentences=sentences,
        batch_size=args['end_model_batch_size'],
        show_progress_bar=args['show_progress_bar'],
        normalize_embeddings=args['normalize_embeddings'],
    )
    embeddings_file = join(args['data_path'], args['dataset'], f'{split}_embeddings.pkl')
    with open(embeddings_file, 'wb') as f:
        pickle.dump(embeddings, f)


### Labeling Data

In [None]:
# print("Labeling Data")
# label_data.run(args)

In [None]:
# Parse the configuration file named by the (emulated) command-line arguments
# and echo it for reference.
args = utils.Parser(config_file_path=args_cmd.config).parse()
print(args)

# Raw training documents.
train_text = utils.fetch_data(dataset=args['dataset'], path=args['data_path'], split='train')

# Everything belonging to this dataset lives under <data_path>/<dataset>/.
dataset_dir = join(args['data_path'], args['dataset'])

# Ground-truth training labels are optional. When the file is present, each
# line is parsed as one integer label.
labels_file = join(dataset_dir, 'train_labels.txt')
if exists(labels_file):
    with open(labels_file, 'r') as f:
        y_train = np.array([int(line.rstrip('\n')) for line in f.readlines()])
    training_labels_present = True
else:
    y_train, training_labels_present = None, False
    print('No training labels found!')

# Pre-computed training embeddings (written by the encoding step).
with open(join(dataset_dir, 'train_embeddings.pkl'), 'rb') as f:
    X_train = pickle.load(f)

# Dataset statistics
print(f"Getting labels for the {args['dataset']} data...")
print(f'Size of the data: {len(train_text)}')
if training_labels_present:
    print('Class distribution', np.unique(y_train, return_counts=True))

# Label names/descriptions: the value of every config entry whose key contains 'target'.
label_names = [args[key] for key in args if 'target' in key]

# Build the labeling functions and obtain probabilistic labels from the label model.
labeler = create_lfs.CreateLabellingFunctions(
    base_encoder=args['base_encoder'],
    device=torch.device(args['device']),
    label_model=args['label_model'],
)
proba_preds = labeler.get_labels(
    text_corpus=train_text,
    label_names=label_names,
    min_df=args['min_df'],
    ngram_range=args['ngram_range'],
    topk=args['topk'],
    y_train=y_train,
    label_model_lr=args['label_model_lr'],
    label_model_n_epochs=args['label_model_n_epochs'],
    verbose=True,
    n_classes=args['n_classes'],
)

# Hard labels: index of the highest-probability class for each row.
y_train_pred = np.argmax(proba_preds, axis=1)

# Persist the probabilistic predictions, creating the output directory if needed.
preds_dir = args['preds_path']
if not exists(preds_dir):
    os.makedirs(preds_dir)
preds_file = join(preds_dir, f"{args['label_model']}_proba_preds.pkl")
with open(preds_file, 'wb') as f:
    pickle.dump(proba_preds, f)

# Prediction statistics; accuracy and logged metrics need ground-truth labels.
print('Label Model Predictions: Unique value and counts', np.unique(y_train_pred, return_counts=True))
if training_labels_present:
    print('Label Model Training Accuracy', np.mean(y_train_pred==y_train))

    # Log the metrics
    training_metrics_with_gt = utils.compute_metrics(
        y_preds=y_train_pred,
        y_true=y_train,
        average=args['average'],
    )
    utils.log(
        metrics=training_metrics_with_gt,
        filename='label_model_with_ground_truth',
        results_dir=args['results_path'],
        split='train',
    )


In [None]:
# print("Training Model")
# results = train_downstream_model.train(args)
# print("Model Results:")
# print(results)

### Training Downstream Model

In [None]:
# Parse the configuration file named by the (emulated) command-line arguments.
args = utils.Parser(config_file_path=args_cmd.config).parse()

# Seed torch and numpy with the command-line seed.
random_seed = args_cmd.random_seed
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# load_data's result is unpacked into eight values, in this order.
(X_train_embed_masked, y_train_lm_masked, y_train_masked,
 X_test_embed, y_test, training_labels_present,
 sample_weights_masked, proba_preds_masked) = train_downstream_model.load_data(args)

# Build the downstream classifier on top of the configured encoder.
if args['use_custom_encoder']:
    encoder = models.CustomEncoder(
        pretrained_model_name_or_path=args['base_encoder'],
        device=args['device'],
    )
else:
    encoder = models.Encoder(
        model_name=args['base_encoder'],
        device=args['device'],
    )

# NOTE(review): `activation`, `criterion`, `end_model_lr` and
# `end_model_weight_decay` are config values passed through eval(); only run
# this notebook with configuration files you trust.
classifier = models.FeedForwardFlexible(
    encoder_model=encoder,
    h_sizes=args['h_sizes'],
    activation=eval(args['activation']),
    device=torch.device(args['device']),
)

print('\n===== Training the downstream classifier =====\n')

# Sample weights are only forwarded when the noise-aware loss is enabled.
if args['use_noise_aware_loss']:
    train_sample_weights = sample_weights_masked
else:
    train_sample_weights = None

model = train_classifier.train(
    model=classifier,
    device=torch.device(args['device']),
    X_train=X_train_embed_masked,
    y_train=y_train_lm_masked,
    sample_weights=train_sample_weights,
    epochs=args['end_model_epochs'],
    batch_size=args['end_model_batch_size'],
    criterion=eval(args['criterion']),
    raw_text=False,  # inputs are pre-computed embeddings, not strings
    lr=eval(args['end_model_lr']),
    weight_decay=eval(args['end_model_weight_decay']),
    patience=args['end_model_patience'],
)

# Class probabilities of the trained end model on the train and test embeddings.
end_model_preds_train = model.predict_proba(
    torch.from_numpy(X_train_embed_masked), batch_size=512, raw_text=False)
end_model_preds_test = model.predict_proba(
    torch.from_numpy(X_test_embed), batch_size=512, raw_text=False)




### Self-training Model

In [None]:
# Self-training operates on raw text rather than pre-computed embeddings,
# so fetch the documents of both splits.
X_train_text = utils.fetch_data(dataset=args['dataset'], path=args['data_path'], split='train')
X_test_text = utils.fetch_data(dataset=args['dataset'], path=args['data_path'], split='test')

# Continue training the end model from the previous cell.
# NOTE(review): the test split is what gets passed as X_val / y_val here.
# NOTE(review): `self_train_lr`, `self_train_weight_decay` and
# `self_train_thresh` are config values passed through eval(); only run this
# notebook with configuration files you trust.
model = train_classifier.self_train(
    model=model,
    X_train=X_train_text,
    X_val=X_test_text,
    y_val=y_test,
    device=torch.device(args['device']),
    lr=eval(args['self_train_lr']),
    weight_decay=eval(args['self_train_weight_decay']),
    patience=args['self_train_patience'],
    batch_size=args['self_train_batch_size'],
    q_update_interval=args['q_update_interval'],
    self_train_thresh=eval(args['self_train_thresh']),
    print_eval=True,
)

# Class probabilities of the self-trained model on the raw test text.
end_model_preds_test = model.predict_proba(
    X_test_text,
    batch_size=args['self_train_batch_size'],
    raw_text=True,
)

# Bootstrapped test metrics computed from the arg-max class per document.
testing_metrics = utils.compute_metrics_bootstrap(
    y_preds=np.argmax(end_model_preds_test, axis=1),
    y_true=y_test,
    average=args['average'],
    n_bootstrap=args['n_bootstrap'],
    n_jobs=args['n_jobs'],
)
print(testing_metrics)

## Testing

In [None]:
# Evaluate previously saved end-model checkpoints.
#
# `args_cmd` is the argparse namespace (carrying `.config` and `.random_seed`),
# whereas `args` is the parsed configuration mapping indexed as args['...'].
# The callee's parameter is named `args_cmd`, so the namespace is what is
# passed here; handing it `args` would pass the mapping instead.
# NOTE(review): `train_downstream_model.test` is not visible from this
# notebook - confirm it expects the namespace.
#
# NOTE(review): both paths name specific timestamped checkpoint files; point
# them at the checkpoints produced by your own training run.
results = train_downstream_model.test(args_cmd=args_cmd,
                                      end_model_path='../models/imdb/end_model_26-Jul-2022-03_29_41.pth',
                                      end_model_self_trained_path='../models/imdb/end_model_self_trained_26 Jul 2022 03:59:43.pth')
