In [1]:
import sys
sys.path.append('/dccstor/hoo-misha-1/wilds/wilds/examples')
sys.path.append('/dccstor/hoo-misha-1/wilds/WOODS')
sys.path.append('/dccstor/hoo-misha-1/wilds/wilds')

import os

import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn

from transformers import DistilBertModel, DistilBertTokenizerFast
from models.bert.distilbert import DistilBertClassifier, DistilBertFeaturizer
from configs.datasets import dataset_defaults

import wilds
from wilds.common.data_loaders import get_train_loader, get_eval_loader
from wilds.common.grouper import CombinatorialGrouper
from wilds.datasets.unlabeled.wilds_unlabeled_dataset import WILDSPseudolabeledSubset

from utils import set_seed, Logger, BatchLogger, log_config, ParseKwargs, load, initialize_wandb, log_group_data, parse_bool, get_model_prefix, move_to
from train import train, evaluate, infer_predictions,run_epoch
from algorithms.initializer import initialize_algorithm, infer_d_out
from transforms import initialize_transform

from models.initializer import initialize_model
from configs.utils import populate_defaults
import configs.supported as supported

import torch.multiprocessing

import torchvision.transforms as transforms

from examples.transforms import initialize_bert_transform

from tqdm import tqdm
import argparse
import copy
import re
import psutil
from collections import defaultdict

# Initialize WILDS Config

In [2]:
# Command-line interface for the experiment.
# Arg defaults are filled in according to examples/configs/
#
# Boolean options use `type=parse_bool, const=True, nargs='?'`: with nargs='?' the
# value is optional, so a bare flag (`--download`) takes `const` (True) while an
# explicit value (`--download False`) is converted by parse_bool.
parser = argparse.ArgumentParser()

# Required arguments
parser.add_argument('-d', '--dataset', choices=wilds.supported_datasets, required=True)
parser.add_argument('--algorithm', required=True, choices=supported.algorithms)
parser.add_argument('--root_dir', required=True,
                    help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')

# Dataset
parser.add_argument('--split_scheme', help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')
parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={},
                    help='keyword arguments for dataset initialization passed as key1=value1 key2=value2')
parser.add_argument('--download', default=False, type=parse_bool, const=True, nargs='?',
                    help='If true, tries to download the dataset if it does not exist in root_dir.')
parser.add_argument('--frac', type=float, default=1.0,
                    help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes. Note that this also scales the test set down, so the reported numbers are not comparable with the full test set.')
parser.add_argument('--version', default=None, type=str, help='WILDS labeled dataset version number.')

# Unlabeled Dataset
parser.add_argument('--unlabeled_split', default=None, type=str, choices=wilds.unlabeled_splits,  help='Unlabeled split to use. Some datasets only have some splits available.')
parser.add_argument('--unlabeled_version', default=None, type=str, help='WILDS unlabeled dataset version number.')
# Help text fixed: the original read "will also the true labels" (missing verb).
parser.add_argument('--use_unlabeled_y', default=False, type=parse_bool, const=True, nargs='?',
                    help='If true, unlabeled loaders will also load the true labels for the unlabeled data. This is only available for some datasets. Used for "fully-labeled ERM experiments" in the paper. Correct functionality relies on CrossEntropyLoss using ignore_index=-100.')

# Loaders
parser.add_argument('--loader_kwargs', nargs='*', action=ParseKwargs, default={})
parser.add_argument('--unlabeled_loader_kwargs', nargs='*', action=ParseKwargs, default={})
parser.add_argument('--train_loader', choices=['standard', 'group'])
parser.add_argument('--uniform_over_groups', type=parse_bool, const=True, nargs='?', help='If true, sample examples such that batches are uniform over groups.')
parser.add_argument('--distinct_groups', type=parse_bool, const=True, nargs='?', help='If true, enforce groups sampled per batch are distinct.')
parser.add_argument('--n_groups_per_batch', type=int)
parser.add_argument('--unlabeled_n_groups_per_batch', type=int)
parser.add_argument('--batch_size', type=int)
parser.add_argument('--unlabeled_batch_size', type=int)
parser.add_argument('--eval_loader', choices=['standard'], default='standard')
parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help='Number of batches to process before stepping optimizer and schedulers. If > 1, we simulate having a larger effective batch size (though batchnorm behaves differently).')

# Model
parser.add_argument('--model', choices=supported.models)
parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},
                    help='keyword arguments for model initialization passed as key1=value1 key2=value2')
parser.add_argument('--noisystudent_add_dropout', type=parse_bool, const=True, nargs='?', help='If true, adds a dropout layer to the student model of NoisyStudent.')
parser.add_argument('--noisystudent_dropout_rate', type=float)
parser.add_argument('--pretrained_model_path', default=None, type=str, help='Specify a path to pretrained model weights')
parser.add_argument('--load_featurizer_only', default=False, type=parse_bool, const=True, nargs='?', help='If true, only loads the featurizer weights and not the classifier weights.')

# NoisyStudent-specific loading
parser.add_argument('--teacher_model_path', type=str, help='Path to NoisyStudent teacher model weights. If this is defined, pseudolabels will first be computed for unlabeled data before anything else runs.')

# Transforms
parser.add_argument('--transform', choices=supported.transforms)
parser.add_argument('--additional_train_transform', choices=supported.additional_transforms, help='Optional data augmentations to layer on top of the default transforms.')
parser.add_argument('--target_resolution', nargs='+', type=int, help='The input resolution that images will be resized to before being passed into the model. For example, use --target_resolution 224 224 for a standard ResNet.')
parser.add_argument('--resize_scale', type=float)
parser.add_argument('--max_token_length', type=int)
parser.add_argument('--randaugment_n', type=int, help='Number of RandAugment transformations to apply.')

# Objective
parser.add_argument('--loss_function', choices=supported.losses)
parser.add_argument('--loss_kwargs', nargs='*', action=ParseKwargs, default={},
                    help='keyword arguments for loss initialization passed as key1=value1 key2=value2')

# Algorithm
parser.add_argument('--groupby_fields', nargs='+')
parser.add_argument('--group_dro_step_size', type=float)
parser.add_argument('--coral_penalty_weight', type=float)
parser.add_argument('--wasserstein_blur', type=float, default=0.0001)
parser.add_argument('--dann_penalty_weight', type=float)
parser.add_argument('--dann_classifier_lr', type=float)
parser.add_argument('--dann_featurizer_lr', type=float)
parser.add_argument('--dann_discriminator_lr', type=float)
parser.add_argument('--afn_penalty_weight', type=float)
parser.add_argument('--safn_delta_r', type=float)
parser.add_argument('--hafn_r', type=float)
parser.add_argument('--use_hafn', default=False, type=parse_bool, const=True, nargs='?')
parser.add_argument('--irm_lambda', type=float)
parser.add_argument('--irm_penalty_anneal_iters', type=int)
parser.add_argument('--self_training_lambda', type=float)
parser.add_argument('--self_training_threshold', type=float)
parser.add_argument('--pseudolabel_T2', type=float, help='Percentage of total iterations at which to end linear scheduling and hold lambda at the max value')
parser.add_argument('--soft_pseudolabels', default=False, type=parse_bool, const=True, nargs='?')
parser.add_argument('--algo_log_metric')
parser.add_argument('--process_pseudolabels_function', choices=supported.process_pseudolabels_functions)

# Model selection
parser.add_argument('--val_metric')
parser.add_argument('--val_metric_decreasing', type=parse_bool, const=True, nargs='?')

# Optimization
parser.add_argument('--n_epochs', type=int)
parser.add_argument('--optimizer', choices=supported.optimizers)
parser.add_argument('--lr', type=float)
parser.add_argument('--weight_decay', type=float)
parser.add_argument('--max_grad_norm', type=float)
parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={},
                    help='keyword arguments for optimizer initialization passed as key1=value1 key2=value2')

# Scheduler
parser.add_argument('--scheduler', choices=supported.schedulers)
parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={},
                    help='keyword arguments for scheduler initialization passed as key1=value1 key2=value2')
parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')
parser.add_argument('--scheduler_metric_name')

# Evaluation
parser.add_argument('--process_outputs_function', choices = supported.process_outputs_functions)
parser.add_argument('--evaluate_all_splits', type=parse_bool, const=True, nargs='?', default=True)
parser.add_argument('--eval_splits', nargs='+', default=[])
parser.add_argument('--eval_only', type=parse_bool, const=True, nargs='?', default=False)
parser.add_argument('--eval_epoch', default=None, type=int, help='If eval_only is set, then eval_epoch allows you to specify evaluating at a particular epoch. By default, it evaluates the best epoch by validation performance.')

# Misc
parser.add_argument('--device', type=int, nargs='+', default=[0])
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--log_dir', default='./logs')
parser.add_argument('--log_every', default=50, type=int)
parser.add_argument('--save_step', type=int)
parser.add_argument('--save_best', type=parse_bool, const=True, nargs='?', default=True)
parser.add_argument('--save_last', type=parse_bool, const=True, nargs='?', default=True)
parser.add_argument('--save_pred', type=parse_bool, const=True, nargs='?', default=True)
parser.add_argument('--no_group_logging', type=parse_bool, const=True, nargs='?')
parser.add_argument('--progress_bar', type=parse_bool, const=True, nargs='?', default=False)
parser.add_argument('--resume', type=parse_bool, const=True, nargs='?', default=False, help='Whether to resume from the most recent saved model in the current log_dir.')

# Weights & Biases
parser.add_argument('--use_wandb', type=parse_bool, const=True, nargs='?', default=False)
parser.add_argument('--wandb_api_key_path', type=str,
                    help="Path to Weights & Biases API Key. If use_wandb is set to True and this argument is not specified, user will be prompted to authenticate.")
parser.add_argument('--wandb_kwargs', nargs='*', action=ParseKwargs, default={},
                    help='keyword arguments for wandb.init() passed as key1=value1 key2=value2')

# BREEDS
# Trailing semicolon: add_argument returns the created action, and as the last
# expression of the cell its repr would otherwise be echoed as cell output.
parser.add_argument('--breeds', type=parse_bool, const=True, nargs='?', default=False);

_StoreAction(option_strings=['--breeds'], dest='breeds', nargs='?', const=True, default=False, type=<function parse_bool at 0x152f4b6be9d0>, choices=None, help=None, metavar=None)

In [3]:
def update_config(parser, dataset, algorithm):
    """Build the experiment config for ``algorithm`` and open its log file.

    A fixed command line is parsed with ``parser``, completed by
    ``populate_defaults``, and the compute device is resolved from it.

    Side effects: rebinds the module-level names ``config``, ``logger`` and
    ``mode`` (later cells read them), sets ``CUDA_VISIBLE_DEVICES`` when CUDA
    is available, and creates ``config.log_dir`` if it is missing.

    Args:
        parser: The ``argparse.ArgumentParser`` the command line is parsed with.
        dataset: Value passed as ``--dataset``.
        algorithm: Value passed as ``--algorithm``; also used as the last
            component of the log directory.

    Returns:
        The ``Logger`` writing to ``<config.log_dir>/log.txt``.

    Raises:
        ValueError: If more devices are requested than CUDA reports.
    """
    global config, logger, mode
    print(f'|   Updating config to use algorithm {algorithm}')

    # argv is passed as a list (not a whitespace-split string) so each value
    # stays a single token whatever characters it contains.
    # NOTE: the log directory contains 'cifar100' regardless of `dataset`.
    cli_args = [
        '--dataset', dataset,
        '--algorithm', algorithm,
        '--root_dir', '/dccstor/hoo-misha-1/wilds/wilds/data',
        '--wasserstein_blur', '0.000001',
        '--coral_penalty_weight', '0.0001',
        '--log_dir', f'/dccstor/hoo-misha-1/wilds/WOODS/logs/cifar100/{algorithm}',
        '--breeds', 'True',
        '--evaluate_all_splits', 'False',
        # Optional switches kept from the original cell (disabled):
        # '--eval_only',
        # '--model_kwargs', 'ignore_mismatched_sizes=True',
        # '--use_wandb',
    ]
    config = populate_defaults(parser.parse_args(cli_args))

    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        if len(config.device) > device_count:
            raise ValueError(f"Specified {len(config.device)} devices, but only {device_count} devices found.")

        config.use_data_parallel = len(config.device) > 1
        # NOTE(review): this is set after torch.cuda has already been queried above;
        # confirm it still restricts the visible devices in this environment.
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, config.device))
        # From here on config.device is a torch.device, no longer the parsed list of ints.
        config.device = torch.device("cuda")
    else:
        config.use_data_parallel = False
        config.device = torch.device("cpu")

    # Initialize logs: append only when continuing in an existing directory
    # (resume or eval-only); otherwise start the log files from scratch.
    log_dir_exists = os.path.exists(config.log_dir)
    mode = 'a' if log_dir_exists and (config.resume or config.eval_only) else 'w'

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(config.log_dir, exist_ok=True)
    logger = Logger(os.path.join(config.log_dir, 'log.txt'), mode)
    return logger

In [4]:
# Build the global config/logger for the Wasserstein deepCORAL run. 'iwildcam' is the
# WILDS dataset the config and datasets are built from; the train/val loaders are
# replaced with a CIFAR-100 loader further down in this notebook.
update_config(parser, 'iwildcam', 'wassersteindeepCORAL')

|   Updating config to use algorithm wassersteindeepCORAL


<utils.Logger at 0x152f48f3a5e0>

In [5]:
# Load the labeled WILDS dataset, build train/eval transforms and the train grouper,
# and - only when config.unlabeled_split is set - prepare the unlabeled dataset and
# its loader. Reads the global `config`; defines `full_dataset`, `train_transform`,
# `eval_transform`, `train_grouper` and `unlabeled_dataset` for later cells.
# Data
full_dataset = wilds.get_dataset(
    dataset=config.dataset,
    version=config.version,
    root_dir=config.root_dir,
    download=config.download,
    split_scheme=config.split_scheme,
    **config.dataset_kwargs)

# Transforms & data augmentations for labeled dataset
# To modify data augmentation, modify the following code block.
# If you want to use transforms that modify both `x` and `y`,
# set `do_transform_y` to True when initializing the `WILDSSubset` below.
train_transform = initialize_transform(
    transform_name=config.transform,
    config=config,
    dataset=full_dataset,
    additional_transform_name=config.additional_train_transform,
    is_training=True)
eval_transform = initialize_transform(
    transform_name=config.transform,
    config=config,
    dataset=full_dataset,
    is_training=False)

# Configure unlabeled datasets
# Stays None when no unlabeled split is requested.
unlabeled_dataset = None
if config.unlabeled_split is not None:
    split = config.unlabeled_split
    full_unlabeled_dataset = wilds.get_dataset(
        dataset=config.dataset,
        version=config.unlabeled_version,
        root_dir=config.root_dir,
        download=config.download,
        unlabeled=True,
        **config.dataset_kwargs
    )
    # The grouper is built over both the labeled and the unlabeled dataset.
    train_grouper = CombinatorialGrouper(
        dataset=[full_dataset, full_unlabeled_dataset],
        groupby_fields=config.groupby_fields
    )

    # Transforms & data augmentations for unlabeled dataset
    if config.algorithm == "FixMatch":
        # For FixMatch, we need our loader to return batches in the form ((x_weak, x_strong), m)
        # We do this by initializing a special transform function
        unlabeled_train_transform = initialize_transform(
            config.transform, config, full_dataset, is_training=True, additional_transform_name="fixmatch"
        )
    else:
        # Otherwise, use the same data augmentations as the labeled data.
        unlabeled_train_transform = train_transform

    if config.algorithm == "NoisyStudent":
        # For Noisy Student, we need to first generate pseudolabels using the teacher
        # and then prep the unlabeled dataset to return these pseudolabels in __getitem__
        print("Inferring teacher pseudolabels for Noisy Student")
        assert config.teacher_model_path is not None
        if not config.teacher_model_path.endswith(".pth"):
            # Use the best model
            config.teacher_model_path = os.path.join(
                config.teacher_model_path,  f"{config.dataset}_seed:{config.seed}_epoch:best_model.pth"
            )

        d_out = infer_d_out(full_dataset, config)
        teacher_model = initialize_model(config, d_out).to(config.device)
        load(teacher_model, config.teacher_model_path, device=config.device)
        # Infer teacher outputs on weakly augmented unlabeled examples in sequential order
        weak_transform = initialize_transform(
            transform_name=config.transform,
            config=config,
            dataset=full_dataset,
            is_training=True,
            additional_transform_name="weak"
        )
        unlabeled_split_dataset = full_unlabeled_dataset.get_subset(split, transform=weak_transform, frac=config.frac)
        sequential_loader = get_eval_loader(
            loader=config.eval_loader,
            dataset=unlabeled_split_dataset,
            grouper=train_grouper,
            batch_size=config.unlabeled_batch_size,
            **config.unlabeled_loader_kwargs
        )
        teacher_outputs = infer_predictions(teacher_model, sequential_loader, config)
        teacher_outputs = move_to(teacher_outputs, torch.device("cpu"))
        # Re-wrap the subset so it carries the teacher outputs as pseudolabels and
        # uses the training-time transform instead of the weak one.
        unlabeled_split_dataset = WILDSPseudolabeledSubset(
            reference_subset=unlabeled_split_dataset,
            pseudolabels=teacher_outputs,
            transform=unlabeled_train_transform,
            collate=full_dataset.collate,
        )
        # The teacher is only needed for pseudolabel inference; move it off the
        # compute device and drop the reference.
        teacher_model = teacher_model.to(torch.device("cpu"))
        del teacher_model
    else:
        unlabeled_split_dataset = full_unlabeled_dataset.get_subset(
            split,
            transform=unlabeled_train_transform,
            frac=config.frac,
            load_y=config.use_unlabeled_y
        )

    unlabeled_dataset = {
        'split': split,
        'name': full_unlabeled_dataset.split_names[split],
        'dataset': unlabeled_split_dataset
    }
    unlabeled_dataset['loader'] = get_train_loader(
        loader=config.train_loader,
        dataset=unlabeled_dataset['dataset'],
        batch_size=config.unlabeled_batch_size,
        uniform_over_groups=config.uniform_over_groups,
        grouper=train_grouper,
        distinct_groups=config.distinct_groups,
        n_groups_per_batch=config.unlabeled_n_groups_per_batch,
        **config.unlabeled_loader_kwargs
    )
else:
    # No unlabeled data: group over the labeled dataset only.
    train_grouper = CombinatorialGrouper(
        dataset=full_dataset,
        groupby_fields=config.groupby_fields
    )

# Configure labeled torch datasets (WILDS dataset splits)
# One entry per split: the subset, its loader, bookkeeping fields and two CSV loggers.
wilds_datasets = defaultdict(dict)
for split in full_dataset.split_dict.keys():
    is_train_split = split == 'train'
    # Only the training split gets the training transform; train and val are verbose.
    transform = train_transform if is_train_split else eval_transform
    verbose = split in ('train', 'val')

    # Get subset
    split_subset = full_dataset.get_subset(split, frac=config.frac, transform=transform)
    split_entry = wilds_datasets[split]
    split_entry['dataset'] = split_subset

    if not is_train_split:
        split_entry['loader'] = get_eval_loader(
            loader=config.eval_loader,
            dataset=split_subset,
            grouper=train_grouper,
            batch_size=config.batch_size,
            **config.loader_kwargs)
    else:
        split_entry['loader'] = get_train_loader(
            loader=config.train_loader,
            dataset=split_subset,
            batch_size=config.batch_size,
            uniform_over_groups=config.uniform_over_groups,
            grouper=train_grouper,
            distinct_groups=config.distinct_groups,
            n_groups_per_batch=config.n_groups_per_batch,
            **config.loader_kwargs)

    # Set fields
    split_entry['split'] = split
    split_entry['name'] = full_dataset.split_names[split]
    split_entry['verbose'] = verbose

    # Loggers (opened with the file mode chosen when the config was built)
    eval_csv_path = os.path.join(config.log_dir, f'{split}_eval.csv')
    algo_csv_path = os.path.join(config.log_dir, f'{split}_algo.csv')
    split_entry['eval_logger'] = BatchLogger(eval_csv_path, mode=mode, use_wandb=config.use_wandb)
    split_entry['algo_logger'] = BatchLogger(algo_csv_path, mode=mode, use_wandb=config.use_wandb)


In [6]:
# Force the train subset's class count to 20 before the algorithm is built; 20 equals
# the number of coarse CIFAR-100 categories in `mapping_coarse_fine` defined later.
# NOTE(review): this writes a private attribute; presumably initialize_algorithm
# derives the model's output size from it - confirm.
wilds_datasets['train']['dataset']._n_classes = 20

In [7]:
# Initialize algorithm & load pretrained weights if provided
# `unlabeled_dataset` is whatever the data cell produced: None unless
# config.unlabeled_split was set.
algorithm = initialize_algorithm(
    config=config,
    datasets=wilds_datasets,
    train_grouper=train_grouper,
    unlabeled_dataset=unlabeled_dataset,
)

In [8]:
from torchvision.datasets import CIFAR100
from torch.utils.data.dataloader import DataLoader
import torchvision.transforms as tt

from scripts.datasets import CustomCIFAR100

from PIL import Image as im

In [9]:
# Per-channel (mean, std) handed to Normalize.
# NOTE(review): presumably CIFAR-100 RGB statistics - confirm the source of these values.
stats = (
    (0.5074, 0.4867, 0.4411),
    (0.2011, 0.1987, 0.2025),
)

# NOTE: this rebinds `train_transform`, which previously held the WILDS training transform.
augmentation_steps = [
    tt.RandomHorizontalFlip(),
    tt.RandomCrop(32, padding=4, padding_mode="reflect"),
]
tensor_steps = [
    tt.ToTensor(),
    tt.Normalize(stats[0], stats[1]),
]
train_transform = tt.Compose(augmentation_steps + tensor_steps)

In [10]:
# Load CIFAR-100 (downloading it if missing) with the augmentation pipeline defined above.
# No `train=` argument is passed, so torchvision's default split is used; the
# 50,000-row metadata computed below is consistent with the training split.
dataset = CIFAR100(root='/dccstor/hoo-misha-1/wilds/WOODS/data/', download=True, transform=train_transform)

Files already downloaded and verified


In [11]:
# CIFAR-100 coarse category -> its five fine-grained class names.
# Insertion order matters: it defines the coarse class index below.
mapping_coarse_fine = {
    'aquatic mammals': ['beaver', 'dolphin', 'otter', 'seal', 'whale'],
    'fish': ['aquarium_fish', 'flatfish', 'ray', 'shark', 'trout'],
    'flowers': ['orchid', 'poppy', 'rose', 'sunflower', 'tulip'],
    'food containers': ['bottle', 'bowl', 'can', 'cup', 'plate'],
    'fruit and vegetables': ['apple', 'mushroom', 'orange', 'pear',
                             'sweet_pepper'],
    'household electrical device': ['clock', 'keyboard', 'lamp',
                                    'telephone', 'television'],
    'household furniture': ['bed', 'chair', 'couch', 'table', 'wardrobe'],
    'insects': ['bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach'],
    'large carnivores': ['bear', 'leopard', 'lion', 'tiger', 'wolf'],
    'large man-made outdoor things': ['bridge', 'castle', 'house', 'road',
                                      'skyscraper'],
    'large natural outdoor scenes': ['cloud', 'forest', 'mountain', 'plain',
                                     'sea'],
    'large omnivores and herbivores': ['camel', 'cattle', 'chimpanzee',
                                       'elephant', 'kangaroo'],
    'medium-sized mammals': ['fox', 'porcupine', 'possum', 'raccoon', 'skunk'],
    'non-insect invertebrates': ['crab', 'lobster', 'snail', 'spider', 'worm'],
    'people': ['baby', 'boy', 'girl', 'man', 'woman'],
    'reptiles': ['crocodile', 'dinosaur', 'lizard', 'snake', 'turtle'],
    'small mammals': ['hamster', 'mouse', 'rabbit', 'shrew', 'squirrel'],
    'trees': ['maple_tree', 'oak_tree', 'palm_tree', 'pine_tree',
              'willow_tree'],
    'vehicles 1': ['bicycle', 'bus', 'motorcycle', 'pickup_truck', 'train'],
    'vehicles 2': ['lawn_mower', 'rocket', 'streetcar', 'tank', 'tractor'],
}

# Coarse category -> integer class index (0..19), numbered in the order the
# categories appear above. Derived rather than hand-written so the two tables
# cannot drift apart; yields exactly the previously hard-coded indices.
mapping_coarse_idx = {
    coarse_label: coarse_idx
    for coarse_idx, coarse_label in enumerate(mapping_coarse_fine)
}

In [12]:
# Number of output classes = number of coarse CIFAR-100 categories.
out_dim = len(mapping_coarse_fine)

In [13]:
# Number of domains. A fine class's domain id is its position inside its coarse
# category's list, so the domain count is the largest list length (5 here).
# Derived instead of hard-coded, mirroring how `out_dim` is computed.
domain_dim = max(len(fine_labels) for fine_labels in mapping_coarse_fine.values())

In [14]:
# Derive the lookup tables that tie fine classes to coarse classes and domains.
# The domain of a fine class is its position (0, 1, 2, ...) inside its coarse
# category's list in `mapping_coarse_fine`.
mapping_fine_domain = {}   # fine class name -> domain id
mapping_fine_coarse = {}   # fine class name -> coarse category name
mapping_domain_fine = {}   # domain id -> list of fine class names
# Indexed by the dataset's fine class index; -1 marks "no domain assigned yet".
mapping_idx_domain = [-1] * len(dataset.classes)
for coarse_label, fine_labels in mapping_coarse_fine.items():
    for domain, fine_label in enumerate(fine_labels):
        mapping_fine_coarse[fine_label] = coarse_label
        mapping_fine_domain[fine_label] = domain
        # .index raises ValueError if the name is not one of dataset.classes.
        mapping_idx_domain[dataset.classes.index(fine_label)] = domain
        mapping_domain_fine.setdefault(domain, []).append(fine_label)

# A leftover -1 would pass the later `metadata < 3` filter and silently put an
# unmapped class into the selected domains, so fail loudly here instead.
unmapped_classes = [
    class_name
    for class_name, class_domain in zip(dataset.classes, mapping_idx_domain)
    if class_domain == -1
]
assert not unmapped_classes, f'fine classes without a domain: {unmapped_classes}'

In [15]:
# Coarse class index for every sample: resolve fine index -> coarse index once per
# class, then look it up for each fine target.
coarse_idx_by_fine_idx = [
    mapping_coarse_idx[mapping_fine_coarse[fine_name]]
    for fine_name in dataset.classes
]
coarse_targets = np.array([coarse_idx_by_fine_idx[fine_idx] for fine_idx in dataset.targets])

In [16]:
# Domain id of every sample as a column vector (one row per sample, one column).
sample_domains = np.array([mapping_idx_domain[fine_idx] for fine_idx in dataset.targets])
metadata = sample_domains[:, np.newaxis]

In [17]:
# Sanity check: one row per image, a single domain column.
metadata.shape

(50000, 1)

In [18]:
# Boolean mask of samples whose domain id is below 3 (domains 0, 1 and 2).
# NOTE: the next cell rebinds this same name to integer row indices.
target_domain_idx = metadata < 3

In [19]:
# Convert the boolean mask into the row indices of the selected samples.
# The dtype guard makes this cell safe to re-run: once the name holds integer
# indices, applying np.where again would return the positions of non-zero
# indices rather than the selected rows.
if target_domain_idx.dtype == bool:
    target_domain_idx = np.where(target_domain_idx)[0]

In [20]:
# Wrap the selected rows: raw images, coarse labels and domain metadata, reusing the
# transform attached to the torchvision dataset.
# NOTE(review): argument order assumed to be (images, labels, metadata, transform) -
# confirm against scripts.datasets.CustomCIFAR100.
cifar100_dataset = CustomCIFAR100(dataset.data[target_domain_idx], coarse_targets[target_domain_idx], metadata[target_domain_idx], dataset.transform)

In [21]:
# Shuffled loader over the filtered CIFAR-100 subset.
BATCH_SIZE = 128
train_dl = DataLoader(
    cifar100_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

# Replace the WILDS train/val loaders with the BREEDS-style CIFAR-100 loader

In [22]:
# Train on the CIFAR-100 loader instead of the loader WILDS built for this split.
# Only the 'loader' entry changes; the 'dataset' entry is still the WILDS subset.
wilds_datasets['train']['loader'] = train_dl

In [23]:
# Same assignment as the one made before initialize_algorithm; repeating it here
# does not change the value.
wilds_datasets['train']['dataset']._n_classes = 20

In [24]:
# NOTE(review): validation reuses the *training* loader - the same images, shuffled
# and with the random flip/crop augmentation - so the reported "Validation" metrics
# are not measured on held-out data. Confirm this is intended.
wilds_datasets['val']['loader'] = train_dl

In [25]:
# Match the validation subset's class count to the 20 coarse CIFAR-100 classes.
# NOTE(review): writes a private attribute of the WILDS subset - confirm what reads it.
wilds_datasets['val']['dataset']._n_classes = 20

In [26]:
# from robustness.model_utils import make_and_restore_model
# model, _ = make_and_restore_model(arch='resnet50', dataset =dataset_source)
# model = model.model
# model.needs_y = False

In [27]:
#algorithm.model = model

In [28]:
# Replace each algorithm's output head(s) so the model predicts `out_dim` coarse
# classes (and, for DANN, `domain_dim` domains). New layers are placed on
# config.device rather than a hard-coded 'cuda', so the cell also runs on CPU.
# The sizes 2048 and 1024 and the layer count 6 are kept from the original cell.
# NOTE(review): presumably the featurizer width, the domain classifier's hidden
# width, and the number of its layers to keep - confirm against the model definitions.
# Any other algorithm is left untouched, as before.
if config.algorithm in ('deepCORAL', 'wassersteindeepCORAL'):
    # Both CORAL variants: keep the model's first child and attach a fresh linear head.
    first_child = next(algorithm.model.children())
    algorithm.model = nn.Sequential(first_child, nn.Linear(2048, out_dim).to(config.device))
elif config.algorithm == 'ERM':
    algorithm.model.fc = nn.Linear(2048, out_dim).to(config.device)
elif config.algorithm == 'DANN':
    algorithm.model.classifier = nn.Linear(2048, out_dim).to(config.device)
    # Keep the first six layers of the existing domain classifier and append a
    # new final layer with one output per domain.
    kept_layers = list(algorithm.model.domain_classifier.children())[:6]
    algorithm.model.domain_classifier = nn.Sequential(
        *kept_layers, nn.Linear(1024, domain_dim)
    ).to(config.device)

In [29]:
# new_domain_classifier = []
# count = 0
# children = iter(algorithm.model.domain_classifier.children())
# while count != 6:
#     new_domain_classifier.append(next(children))
#     count += 1
# # new_domain_classifier.append(nn.Linear(1024, 10))
# algorithm.model.domain_classifier = nn.Sequential(*new_domain_classifier).to('cuda')

In [30]:
#algorithm.no_group_logging = True

In [31]:
#algorithm.loss = nn.functional.cross_entropy

In [None]:
# Run training with the patched loaders and model heads.
# NOTE(review): the trailing positional arguments (0, None) are presumably the
# starting epoch offset and the best validation metric so far - confirm against
# train() in the examples package.
train(algorithm, wilds_datasets, logger, config, 0, None)


Epoch [0]:

Train:
objective: 2.921
penalty: 753.734
loss_avg: 2.845
acc_avg: 0.140

objective: 2.476
penalty: 724.743
loss_avg: 2.404
acc_avg: 0.307

objective: 2.116
penalty: 700.258
loss_avg: 2.046
acc_avg: 0.393

objective: 1.872
penalty: 661.365
loss_avg: 1.806
acc_avg: 0.447

objective: 1.709
penalty: 612.962
loss_avg: 1.647
acc_avg: 0.484

Epoch eval:
Average acc: 0.345
Recall macro: 0.345
F1 macro: 0.340

Validation (OOD/Trans):
objective: 1.484
penalty: 0.000
loss_avg: 1.484
acc_avg: 0.542

Epoch eval:
Average acc: 0.542
Recall macro: 0.542
F1 macro: 0.536
Validation F1-macro_all: 0.536
Epoch 0 has the best validation performance so far.


Epoch [1]:

Train:
objective: 1.569
penalty: 592.958
loss_avg: 1.509
acc_avg: 0.527

objective: 1.449
penalty: 559.500
loss_avg: 1.393
acc_avg: 0.565

objective: 1.394
penalty: 509.336
loss_avg: 1.343
acc_avg: 0.579

objective: 1.321
penalty: 463.831
loss_avg: 1.274
acc_avg: 0.601

objective: 1.313
penalty: 432.770
loss_avg: 1.270
acc_avg: 

In [None]:
# Display the model's module structure as the cell output.
algorithm.model