In [1]:
import sys
sys.path.append('/dccstor/hoo-misha-1/wilds/wilds/examples')
import os

import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn

from transformers import DistilBertModel, DistilBertTokenizerFast
from models.bert.distilbert import DistilBertClassifier, DistilBertFeaturizer
from configs.datasets import dataset_defaults

import wilds
from wilds.common.data_loaders import get_train_loader, get_eval_loader
from wilds.common.grouper import CombinatorialGrouper
from wilds.datasets.unlabeled.wilds_unlabeled_dataset import WILDSPseudolabeledSubset

from utils import set_seed, Logger, BatchLogger, log_config, ParseKwargs, load, initialize_wandb, log_group_data, parse_bool, get_model_prefix, move_to
from train import train, evaluate, infer_predictions,run_epoch
from algorithms.initializer import initialize_algorithm, infer_d_out
from transforms import initialize_transform

from models.initializer import initialize_model
from configs.utils import populate_defaults
import configs.supported as supported

import torch.multiprocessing

import torchvision.transforms as transforms

from examples.transforms import initialize_bert_transform

from tqdm import tqdm
import argparse
import copy
import re
import psutil
from collections import defaultdict

In [2]:
''' Arg defaults are filled in according to examples/configs/ '''
# Command-line parser describing one run configuration. In this notebook it is
# not fed sys.argv: update_config() hands it an explicit argument list.
# Options declared without `default=` parse to None.
parser = argparse.ArgumentParser()

# Required arguments
parser.add_argument('-d', '--dataset', choices=wilds.supported_datasets, required=True)
parser.add_argument('--algorithm', required=True, choices=supported.algorithms)
parser.add_argument('--root_dir', required=True,
                    help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')

# Dataset
parser.add_argument('--split_scheme', help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')
# --dataset_kwargs collects zero or more values (nargs='*') through the
# project's ParseKwargs action; per the help text they are key=value pairs.
parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={},
                    help='keyword arguments for dataset initialization passed as key1=value1 key2=value2')
# Boolean flags in this parser use nargs='?' with const=True: a bare
# `--download` yields True, while `--download <value>` is converted by parse_bool.
parser.add_argument('--download', default=False, type=parse_bool, const=True, nargs='?',
                    help='If true, tries to download the dataset if it does not exist in root_dir.')
parser.add_argument('--frac', type=float, default=1.0,
                    help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes. Note that this also scales the test set down, so the reported numbers are not comparable with the full test set.')
parser.add_argument('--version', default=None, type=str, help='WILDS labeled dataset version number.')

# Unlabeled Dataset
# All three options default to "off" (None / False), i.e. no unlabeled data is requested.
parser.add_argument('--unlabeled_split', default=None, type=str, choices=wilds.unlabeled_splits,  help='Unlabeled split to use. Some datasets only have some splits available.')
parser.add_argument('--unlabeled_version', default=None, type=str, help='WILDS unlabeled dataset version number.')
# nargs='?' + const=True: a bare `--use_unlabeled_y` yields True, an explicit value goes through parse_bool.
# The help text previously read "will also the true labels" (missing verb).
parser.add_argument('--use_unlabeled_y', default=False, type=parse_bool, const=True, nargs='?',
                    help='If true, unlabeled loaders will also return the true labels for the unlabeled data. This is only available for some datasets. Used for "fully-labeled ERM experiments" in the paper. Correct functionality relies on CrossEntropyLoss using ignore_index=-100.')

# Loaders
# Boolean flags below use nargs='?' with const=True: the bare flag yields True,
# an explicit value is converted by parse_bool. Flags without `default=` parse to None.
parser.add_argument('--loader_kwargs', nargs='*', action=ParseKwargs, default={})
parser.add_argument('--unlabeled_loader_kwargs', nargs='*', action=ParseKwargs, default={})
parser.add_argument('--train_loader', choices=['standard', 'group'])
parser.add_argument('--uniform_over_groups', type=parse_bool, const=True, nargs='?', help='If true, sample examples such that batches are uniform over groups.')
parser.add_argument('--distinct_groups', type=parse_bool, const=True, nargs='?', help='If true, enforce groups sampled per batch are distinct.')
parser.add_argument('--n_groups_per_batch', type=int)
parser.add_argument('--unlabeled_n_groups_per_batch', type=int)
parser.add_argument('--batch_size', type=int)
parser.add_argument('--unlabeled_batch_size', type=int)
parser.add_argument('--eval_loader', choices=['standard'], default='standard')
parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help='Number of batches to process before stepping optimizer and schedulers. If > 1, we simulate having a larger effective batch size (though batchnorm behaves differently).')

# Model
parser.add_argument('--model', choices=supported.models)
parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},
                    help='keyword arguments for model initialization passed as key1=value1 key2=value2')
parser.add_argument('--noisystudent_add_dropout', type=parse_bool, const=True, nargs='?', help='If true, adds a dropout layer to the student model of NoisyStudent.')
parser.add_argument('--noisystudent_dropout_rate', type=float)
# update_config() in this notebook always supplies --pretrained_model_path.
parser.add_argument('--pretrained_model_path', default=None, type=str, help='Specify a path to pretrained model weights')
parser.add_argument('--load_featurizer_only', default=False, type=parse_bool, const=True, nargs='?', help='If true, only loads the featurizer weights and not the classifier weights.')

# NoisyStudent-specific loading
parser.add_argument('--teacher_model_path', type=str, help='Path to NoisyStudent teacher model weights. If this is defined, pseudolabels will first be computed for unlabeled data before anything else runs.')

# Transforms
parser.add_argument('--transform', choices=supported.transforms)
parser.add_argument('--additional_train_transform', choices=supported.additional_transforms, help='Optional data augmentations to layer on top of the default transforms.')
parser.add_argument('--target_resolution', nargs='+', type=int, help='The input resolution that images will be resized to before being passed into the model. For example, use --target_resolution 224 224 for a standard ResNet.')
parser.add_argument('--resize_scale', type=float)
parser.add_argument('--max_token_length', type=int)
parser.add_argument('--randaugment_n', type=int, help='Number of RandAugment transformations to apply.')

# Objective
parser.add_argument('--loss_function', choices=supported.losses)
parser.add_argument('--loss_kwargs', nargs='*', action=ParseKwargs, default={},
                    help='keyword arguments for loss initialization passed as key1=value1 key2=value2')

# Algorithm
# Algorithm-specific hyperparameters; all are optional.
parser.add_argument('--groupby_fields', nargs='+')
parser.add_argument('--group_dro_step_size', type=float)
parser.add_argument('--coral_penalty_weight', type=float)
parser.add_argument('--dann_penalty_weight', type=float)
parser.add_argument('--dann_classifier_lr', type=float)
parser.add_argument('--dann_featurizer_lr', type=float)
parser.add_argument('--dann_discriminator_lr', type=float)
parser.add_argument('--afn_penalty_weight', type=float)
parser.add_argument('--safn_delta_r', type=float)
parser.add_argument('--hafn_r', type=float)
parser.add_argument('--use_hafn', default=False, type=parse_bool, const=True, nargs='?')
parser.add_argument('--irm_lambda', type=float)
parser.add_argument('--irm_penalty_anneal_iters', type=int)
parser.add_argument('--self_training_lambda', type=float)
parser.add_argument('--self_training_threshold', type=float)
parser.add_argument('--pseudolabel_T2', type=float, help='Percentage of total iterations at which to end linear scheduling and hold lambda at the max value')
parser.add_argument('--soft_pseudolabels', default=False, type=parse_bool, const=True, nargs='?')
parser.add_argument('--algo_log_metric')
parser.add_argument('--process_pseudolabels_function', choices=supported.process_pseudolabels_functions)

# Model selection
parser.add_argument('--val_metric')
parser.add_argument('--val_metric_decreasing', type=parse_bool, const=True, nargs='?')

# Optimization
parser.add_argument('--n_epochs', type=int)
parser.add_argument('--optimizer', choices=supported.optimizers)
parser.add_argument('--lr', type=float)
parser.add_argument('--weight_decay', type=float)
parser.add_argument('--max_grad_norm', type=float)
parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={},
                    help='keyword arguments for optimizer initialization passed as key1=value1 key2=value2')

# Scheduler
parser.add_argument('--scheduler', choices=supported.schedulers)
parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={},
                    help='keyword arguments for scheduler initialization passed as key1=value1 key2=value2')
parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')
parser.add_argument('--scheduler_metric_name')

# Evaluation
parser.add_argument('--process_outputs_function', choices = supported.process_outputs_functions)
parser.add_argument('--evaluate_all_splits', type=parse_bool, const=True, nargs='?', default=True)
parser.add_argument('--eval_splits', nargs='+', default=[])
parser.add_argument('--eval_only', type=parse_bool, const=True, nargs='?', default=False)
parser.add_argument('--eval_epoch', default=None, type=int, help='If eval_only is set, then eval_epoch allows you to specify evaluating at a particular epoch. By default, it evaluates the best epoch by validation performance.')

# Misc
# --device takes one or more integer GPU ids; update_config() later replaces
# the parsed list with a torch.device.
parser.add_argument('--device', type=int, nargs='+', default=[0])
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--log_dir', default='./logs')
parser.add_argument('--log_every', default=50, type=int)
parser.add_argument('--save_step', type=int)
parser.add_argument('--save_best', type=parse_bool, const=True, nargs='?', default=True)
parser.add_argument('--save_last', type=parse_bool, const=True, nargs='?', default=True)
parser.add_argument('--save_pred', type=parse_bool, const=True, nargs='?', default=True)
parser.add_argument('--no_group_logging', type=parse_bool, const=True, nargs='?')
parser.add_argument('--progress_bar', type=parse_bool, const=True, nargs='?', default=False)
parser.add_argument('--resume', type=parse_bool, const=True, nargs='?', default=False, help='Whether to resume from the most recent saved model in the current log_dir.')

# Weights & Biases
parser.add_argument('--use_wandb', type=parse_bool, const=True, nargs='?', default=False)
parser.add_argument('--wandb_api_key_path', type=str,
                    help="Path to Weights & Biases API Key. If use_wandb is set to True and this argument is not specified, user will be prompted to authenticate.")
parser.add_argument('--wandb_kwargs', nargs='*', action=ParseKwargs, default={},
                    help='keyword arguments for wandb.init() passed as key1=value1 key2=value2')

ParseKwargs(option_strings=['--wandb_kwargs'], dest='wandb_kwargs', nargs='*', const=None, default={}, type=None, choices=None, help='keyword arguments for wandb.init() passed as key1=value1 key2=value2', metavar=None)

In [26]:
def update_config(parser, dataset, algorithm, model_path):
    """Re-parse the run configuration and store it in the global ``config``.

    Args:
        parser: The ``argparse.ArgumentParser`` built earlier in the notebook.
        dataset: Dataset name passed as ``--dataset``.
        algorithm: Algorithm name passed as ``--algorithm``.
        model_path: Checkpoint path passed as ``--pretrained_model_path``.

    Side effects:
        Rebinds the module-level ``config`` to the parsed namespace after
        ``populate_defaults`` has been applied, sets ``config.use_data_parallel``,
        replaces ``config.device`` with a ``torch.device`` and, when CUDA is
        available, sets the ``CUDA_VISIBLE_DEVICES`` environment variable.

    Raises:
        ValueError: If more device ids were requested than CUDA devices exist.
        SystemExit: Raised by argparse when an argument is rejected
            (for example a dataset or algorithm outside the allowed choices).
    """
    global config
    print(f'|   Updating config to use algorithm {algorithm} and pretrained model path {model_path}')
    # Pass argv as a list instead of formatting one string and calling .split():
    # splitting breaks any value that contains whitespace (e.g. a checkpoint
    # path with a space) into several tokens.
    argv = [
        '--dataset', str(dataset),
        '--algorithm', str(algorithm),
        '--root_dir', '/dccstor/hoo-misha-1/wilds/wilds/data',
        '--pretrained_model_path', str(model_path),
        # Optional flags kept for reference:
        # '--eval_only',
        # '--model_kwargs', 'ignore_mismatched_sizes=True',
        # '--evaluate_all_splits', 'False',
        # '--use_wandb',
    ]
    config = populate_defaults(parser.parse_args(argv))

    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        if len(config.device) > device_count:
            raise ValueError(f"Specified {len(config.device)} devices, but only {device_count} devices found.")

        config.use_data_parallel = len(config.device) > 1
        device_str = ",".join(map(str, config.device))
        # NOTE(review): this variable is normally read when the CUDA runtime
        # initialises; confirm it still takes effect when set this late in a
        # long-lived kernel.
        os.environ["CUDA_VISIBLE_DEVICES"] = device_str
        # From here on config.device is a torch.device, no longer the parsed list of ids.
        config.device = torch.device("cuda")
    else:
        config.use_data_parallel = False
        config.device = torch.device("cpu")

In [27]:
# Build the initial global `config` (ERM checkpoint for camelyon17).
# NOTE(review): the checkpoint directory used further down points at iwildcam;
# confirm which dataset a fresh top-to-bottom run is supposed to use.
update_config(parser, 'camelyon17', 'ERM', '/dccstor/hoo-misha-1/wilds/wilds/pretrained/camelyon17/camelyon17_ERM.pth')

|   Updating config to use algorithm ERM and pretrained model path /dccstor/hoo-misha-1/wilds/wilds/pretrained/camelyon17/camelyon17_ERM.pth


In [28]:
# Data: load the full labeled dataset described by the current config.
# Explicit options and user-supplied dataset_kwargs are both passed as keywords,
# so a key supplied twice is still rejected by Python with a TypeError.
full_dataset = wilds.get_dataset(
    **{
        'dataset': config.dataset,
        'version': config.version,
        'root_dir': config.root_dir,
        'download': config.download,
        'split_scheme': config.split_scheme,
    },
    **config.dataset_kwargs,
)

In [14]:
def update_transform():
    """Rebuild the global train and eval transforms from the current ``config``.

    Both transforms share the transform name, config and ``full_dataset``; only
    the training transform additionally receives
    ``config.additional_train_transform``. The results are stored in the
    module-level names ``train_transform`` and ``eval_transform``.
    """
    global config, train_transform, eval_transform
    print('|   Updating data transforms')
    # Arguments common to both calls.
    shared_kwargs = dict(
        transform_name=config.transform,
        config=config,
        dataset=full_dataset,
    )
    train_transform = initialize_transform(
        additional_transform_name=config.additional_train_transform,
        is_training=True,
        **shared_kwargs)
    eval_transform = initialize_transform(is_training=False, **shared_kwargs)


In [15]:
# Create the global train_transform / eval_transform for the current config.
update_transform()

|   Updating data transforms


In [16]:
# Grouper over the configured metadata fields; it is handed to every data loader below.
train_grouper = CombinatorialGrouper(groupby_fields=config.groupby_fields, dataset=full_dataset)

In [17]:
def prune_dataset(dataset, split='test', cutoff=25):
    """Drop, in place, all samples whose class has fewer than ``cutoff`` samples.

    Args:
        dataset: Mapping ``split -> {'dataset': subset, ...}`` where ``subset``
            exposes a ``y_array`` label tensor and an ``indices`` array/tensor
            with one entry per sample.
        split: Key of the split to prune.
        cutoff: Classes with strictly fewer samples than this are removed.

    Returns:
        None. ``dataset[split]['dataset'].indices`` is replaced by the entries
        that belong to the remaining classes; nothing changes when no class
        falls below the cutoff.
    """
    subset = dataset[split]['dataset']
    labels = subset.y_array
    classes, counts = labels.unique(return_counts=True)
    rare_classes = classes[counts < cutoff]
    if len(rare_classes) == 0:
        return

    # True = keep the sample, False = it belongs to a rare class.
    keep_mask = torch.ones(labels.shape[0], dtype=torch.bool)
    for rare_class in rare_classes:
        rare_rows = (labels == rare_class).nonzero(as_tuple=True)[0]
        keep_mask[rare_rows] = False

    # The original indexed with the positions of the rare samples, which kept
    # exactly the samples that were meant to be removed; the mask must be used.
    if isinstance(subset.indices, np.ndarray):
        subset.indices = subset.indices[keep_mask.numpy()]
    else:
        subset.indices = subset.indices[keep_mask]

In [18]:
def _build_split_loader(split, subset):
    """Return the data loader for one split, built from the global ``config``.

    The 'train' split gets a training loader, every other split an eval loader.
    """
    if split == 'train':
        return get_train_loader(
            loader=config.train_loader,
            dataset=subset,
            batch_size=config.batch_size,
            uniform_over_groups=config.uniform_over_groups,
            grouper=train_grouper,
            distinct_groups=config.distinct_groups,
            # NOTE(review): hard-coded; config.n_groups_per_batch is ignored here — confirm intent.
            n_groups_per_batch=1,
            **config.loader_kwargs)
    return get_eval_loader(
        loader=config.eval_loader,
        dataset=subset,
        grouper=train_grouper,
        batch_size=config.batch_size,
        **config.loader_kwargs)


def update_datasets():
    """Rebuild the global ``datasets`` and ``datasets_pruned`` dictionaries.

    ``datasets`` maps every split of ``full_dataset`` to a dict with the keys
    'dataset', 'loader', 'split', 'name' and 'verbose'. ``datasets_pruned`` is
    a deep copy in which the 'val' and 'test' subsets have been passed through
    ``prune_dataset`` and all loaders have been rebuilt on the copied subsets.
    """
    global datasets, datasets_pruned
    # Configure labeled torch datasets (WILDS dataset splits)
    datasets = defaultdict(dict)
    for split in full_dataset.split_dict.keys():
        transform = train_transform if split == 'train' else eval_transform
        subset = full_dataset.get_subset(
            split,
            frac=config.frac,
            transform=transform)
        entry = datasets[split]
        entry['dataset'] = subset
        entry['loader'] = _build_split_loader(split, subset)
        entry['split'] = split
        entry['name'] = full_dataset.split_names[split]
        # Decided per split. Previously the flag was read in a second loop and
        # therefore always carried the value computed for the last split.
        entry['verbose'] = split in ('train', 'val')

    datasets_pruned = copy.deepcopy(datasets)
    for split in ['val', 'test']:
        # Membership test first: indexing a defaultdict with a missing split
        # would silently insert an empty entry.
        if split in datasets_pruned:
            prune_dataset(datasets_pruned, split)
    # Loaders copied by deepcopy are replaced so that they are built on the
    # (possibly pruned) copied subsets.
    for split in datasets_pruned.keys():
        datasets_pruned[split]['loader'] = _build_split_loader(
            split, datasets_pruned[split]['dataset'])

In [19]:
# Build the global `datasets` / `datasets_pruned` dictionaries for the current config.
update_datasets()

In [38]:
# Directory scanned for checkpoint files. The trailing slash matters: the
# extraction loop builds file paths by plain string concatenation.
# NOTE(review): this points at iwildcam while the visible update_config() call
# selects camelyon17 — confirm the two agree on a fresh run.
pretrained_model_path = "/dccstor/hoo-misha-1/wilds/wilds/pretrained/iwildcam/"

In [43]:
# Checkpoints are expected to be named "<dataset>_<algorithm>.pth".
# re.escape() guards against regex metacharacters in the dataset name, and the
# raw string keeps `\w` / `\.` from being treated as string escapes (the
# unescaped "." previously matched any character).
regex_pattern = re.compile(re.escape(config.dataset) + r'_(\w*)\.pth')

model_file_paths = os.listdir(pretrained_model_path)
algorithm_names = []
for model_path in model_file_paths:
    # fullmatch: the whole file name has to follow the naming scheme.
    match = regex_pattern.fullmatch(model_path)
    if match is None:
        # Unrelated files (other datasets, logs, ...) used to raise a TypeError on match[1].
        print(f'Skipping {model_path}: not a {config.dataset} checkpoint')
        continue
    print(f'Found pretrained models for algorithms: {match[1]}')
    algorithm_names.append(match[1])

Found pretrained models for algorithms: PseudoLabel
Found pretrained models for algorithms: deepCORAL
Found pretrained models for algorithms: DANN
Found pretrained models for algorithms: AFN
Found pretrained models for algorithms: FixMatch
Found pretrained models for algorithms: ERM


In [84]:
def prune_model():
    """Return a frozen feature extractor for the algorithm in the global ``config``.

    The algorithm is initialised from ``config``, ``datasets`` and
    ``train_grouper``. For ResNet models the first child module of the
    algorithm's model replaces ``algorithm.model``; its parameters are frozen
    and it is switched to evaluation mode before being returned.

    Returns:
        The first child module of the initialised model, frozen and in eval mode.

    Raises:
        Exception: If ``config.model`` does not contain 'resnet'.
    """
    print(f'|   Pruning {config.model}')
    # Fail before the (expensive) algorithm initialisation for unsupported models.
    if 'resnet' not in config.model:
        raise Exception('New model, no pruning done')
    # Indentation prefix with no newline, so that whatever is printed during
    # initialisation continues on this line.
    print("|   |   ", end = '')
    # Initialize algorithm & load pretrained weights if provided
    algorithm = initialize_algorithm(
        config=config,
        datasets=datasets,
        train_grouper=train_grouper,
    )
    # NOTE(review): assumes the model is a container whose first child is the
    # featurizer; for a plain ResNet the first child would be its first layer — confirm.
    featurizer = next(algorithm.model.children())
    algorithm.model = featurizer
    for param in featurizer.parameters():
        param.requires_grad = False
    # Modules start in training mode; without eval() normalisation layers would
    # use per-batch statistics and keep updating their running statistics
    # while features are being extracted.
    featurizer.eval()
    return featurizer

In [86]:
# For every discovered checkpoint: rebuild config/transforms/datasets, obtain the
# frozen feature extractor and dump features, labels and metadata of each split
# to .npy files.
for algorithm in algorithm_names:
    # NOTE(review): PseudoLabel checkpoints are skipped; the reason is not visible here.
    if algorithm in ["PseudoLabel"]:
        continue
    model_path = pretrained_model_path + config.dataset + "_" + algorithm + ".pth"
    # `config` still describes the previous algorithm at this point, so report
    # the loop variable rather than config.algorithm.
    print(f'Loading {model_path} for algorithm {algorithm} using model {config.model}')
    # update_config() expects the dataset name as its second positional argument.
    update_config(parser, config.dataset, algorithm, model_path)
    update_transform()
    update_datasets()
    model = prune_model()

    for split, dataset in datasets.items():
        print(f'|   Processing {split}')
        loader = dataset['loader']
        # Collect per-batch arrays and concatenate once at the end; stacking
        # inside the loop re-copies everything gathered so far on every batch.
        feature_chunks, label_chunks, metadata_chunks = [], [], []
        # Pure inference: no autograd bookkeeping is needed.
        with torch.no_grad():
            for X_batch, y_batch, metadata_batch in tqdm(loader):
                try:
                    features_batch = model(X_batch.to(config.device))
                    if isinstance(features_batch, tuple):
                        features_batch = features_batch[0]
                    features_batch = features_batch.to('cpu').numpy()
                    y_batch = y_batch.to('cpu').numpy()
                    metadata_batch = np.asarray(metadata_batch)
                except Exception as e:
                    # Best effort: report the failing batch and move on. Nothing
                    # from this batch is stored, so the three arrays stay aligned.
                    print(f'|   |   Caught exception {e}')
                    continue
                feature_chunks.append(features_batch)
                label_chunks.append(y_batch)
                metadata_chunks.append(metadata_batch)

        if not feature_chunks:
            # Every batch failed (or the loader was empty): there is nothing to save.
            print(f'|   |   No batches were processed for {split}; skipping save')
            continue

        features = np.vstack(feature_chunks)
        y = np.concatenate(label_chunks)
        metadata = np.vstack(metadata_chunks)
        print(f'|   |   Features has shape {features.shape} and labels has shape {y.shape}')
        save_path_base = f'/dccstor/hoo-misha-1/wilds/wilds/features/{config.dataset}/{config.algorithm}'
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(save_path_base, exist_ok=True)
        features_save_path = f'{save_path_base}/{config.model}_{split}_features.npy'
        labels_save_path = f'{save_path_base}/{config.model}_{split}_labels.npy'
        metadata_save_path = f'{save_path_base}/{config.model}_{split}_metadata.npy'
        np.save(features_save_path, features)
        np.save(labels_save_path, y)
        np.save(metadata_save_path, metadata)
        print(f'|   |   |   Features saved to {features_save_path}')
        print(f'|   |   |   Labels saved to {labels_save_path}')
        print(f'|   |   |    Metadata saved to {metadata_save_path}')
    print(f'Completed features and labels for {config.algorithm}')
    

Loading /dccstor/hoo-misha-1/wilds/wilds/pretrained/iwildcam/iwildcam_deepCORAL.pth for algorithm deepCORAL using model resnet50
|   Updating config to use algorithm deepCORAL and pretrained model path /dccstor/hoo-misha-1/wilds/wilds/pretrained/iwildcam/iwildcam_deepCORAL.pth
|   Updating data transforms
|   Pruning resnet50
|   |   Initialized model with pretrained weights from /dccstor/hoo-misha-1/wilds/wilds/pretrained/iwildcam/iwildcam_deepCORAL.pth previously trained for 2 epochs with previous val metric 0.2722762855576651 
|   Processing train


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8113/8113 [37:04<00:00,  3.65it/s]


|   |   Features has shape (129808, 2048) and labels has shape (129808,)
|   |   |   Features saved to /dccstor/hoo-misha-1/wilds/wilds/features/iwildcam/deepCORAL/resnet50_train_features.npy
|   |   |   Labels saved to /dccstor/hoo-misha-1/wilds/wilds/features/iwildcam/deepCORAL/resnet50_train_labels.npy
|   |   |    Metadata saved to /dccstor/hoo-misha-1/wilds/wilds/features/iwildcam/deepCORAL/resnet50_train_metadata.npy
|   Processing val


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 936/936 [01:28<00:00, 10.58it/s]


|   |   Features has shape (14961, 2048) and labels has shape (14961,)
|   |   |   Features saved to /dccstor/hoo-misha-1/wilds/wilds/features/iwildcam/deepCORAL/resnet50_val_features.npy
|   |   |   Labels saved to /dccstor/hoo-misha-1/wilds/wilds/features/iwildcam/deepCORAL/resnet50_val_labels.npy
|   |   |    Metadata saved to /dccstor/hoo-misha-1/wilds/wilds/features/iwildcam/deepCORAL/resnet50_val_metadata.npy
|   Processing test


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2675/2675 [06:11<00:00,  7.21it/s]


|   |   Features has shape (42791, 2048) and labels has shape (42791,)
|   |   |   Features saved to /dccstor/hoo-misha-1/wilds/wilds/features/iwildcam/deepCORAL/resnet50_test_features.npy
|   |   |   Labels saved to /dccstor/hoo-misha-1/wilds/wilds/features/iwildcam/deepCORAL/resnet50_test_labels.npy
|   |   |    Metadata saved to /dccstor/hoo-misha-1/wilds/wilds/features/iwildcam/deepCORAL/resnet50_test_metadata.npy
|   Processing id_val


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 458/458 [00:37<00:00, 12.30it/s]


|   |   Features has shape (7314, 2048) and labels has shape (7314,)
|   |   |   Features saved to /dccstor/hoo-misha-1/wilds/wilds/features/iwildcam/deepCORAL/resnet50_id_val_features.npy
|   |   |   Labels saved to /dccstor/hoo-misha-1/wilds/wilds/features/iwildcam/deepCORAL/resnet50_id_val_labels.npy
|   |   |    Metadata saved to /dccstor/hoo-misha-1/wilds/wilds/features/iwildcam/deepCORAL/resnet50_id_val_metadata.npy
|   Processing id_test


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 510/510 [00:42<00:00, 12.06it/s]


|   |   Features has shape (8154, 2048) and labels has shape (8154,)
|   |   |   Features saved to /dccstor/hoo-misha-1/wilds/wilds/features/iwildcam/deepCORAL/resnet50_id_test_features.npy
|   |   |   Labels saved to /dccstor/hoo-misha-1/wilds/wilds/features/iwildcam/deepCORAL/resnet50_id_test_labels.npy
|   |   |    Metadata saved to /dccstor/hoo-misha-1/wilds/wilds/features/iwildcam/deepCORAL/resnet50_id_test_metadata.npy
Completed features and labels for deepCORAL
Loading /dccstor/hoo-misha-1/wilds/wilds/pretrained/iwildcam/iwildcam_DANN.pth for algorithm deepCORAL using model resnet50
|   Updating config to use algorithm DANN and pretrained model path /dccstor/hoo-misha-1/wilds/wilds/pretrained/iwildcam/iwildcam_DANN.pth
|   Updating data transforms
|   Pruning resnet50
|   |   Initialized model with pretrained weights from /dccstor/hoo-misha-1/wilds/wilds/pretrained/iwildcam/iwildcam_DANN.pth previously trained for 1 epochs with previous val metric 0.36872491313728456 
|   Proces

  5%|███████████▍                                                                                                                                                                                                          | 434/8113 [00:35<10:20, 12.38it/s]


KeyboardInterrupt: 

In [None]:
# Completion marker, printed once the extraction loop above has finished.
print('All algorithms featurized')