In [20]:
import argparse
import numpy as np
from sklearn.metrics import normalized_mutual_info_score

import torch
from torch import nn

from dataloader import get_dataset
from kmeans import get_cluster_centers
from module import Encoder
from adverserial import adv_loss
from eval import predict, cluster_accuracy, balance
from utils import set_seed, AverageMeter, target_distribution, aff, inv_lr_scheduler
import os
import wandb  # Used to log progress and plot graphs. 
from vae import DFC_VAE
from vae import train as train_vae
from dfc import train as train_dfc
from dec import train as train_dec
from dfc import DFC
from resnet50_finetune import *
import torchvision.models as models

import pytorch_lightning as pl
from pl_bolts.models.autoencoders import VAE
import pandas as pd

In [32]:
from Args_notebook.py import args

## Define Functions

In [39]:

def get_encoder(args, log_name, legacy_path, path, dataloader_list, device='cpu', encoder_type='vae'):
    if encoder_type == 'vae':
        print('Loading the variational autoencoder')
        if legacy_path:
            encoder = Encoder().to(device)
            encoder.load_state_dict(torch.load(
                legacy_path, map_location=device))
        else:
            if path:
                model = DFC_VAE.load_from_checkpoint(path).to(device)
            else:
                model = train_vae(args, log_name,  dataloader_list, args.input_height,
                                  is_digit_dataset=args.digital_dataset, device=device).to(device)
            encoder = model.encoder
    elif encoder_type == 'resnet50':  # Maybe fine tune resnet50 here
        print('Loading the RESNET50 encoder')
        encoder = models.resnet50(pretrained=True, progress=True)
        
        set_parameter_requires_grad(encoder, req_grad=False)
        encoder.fc = nn.Linear(1000, args.dfc_hidden_dim) #TODO: Reshape and finetune resnet50        
        # get_update_param(encoder)
        encoder = encoder.to(device)
        # encoder, val_acc_history = train_last_layer_resnet50( #train for the 31 classes
            # encoder, dataloader_list, log_name=log_name, device=device, args=args, num_classes=args.dfc_hidden_dim)

    else:
        raise NameError('The encoder_type variable has an unvalid value')
    wandb.watch(encoder)
    return encoder


In [40]:

def get_dec(args, path, dataloader_list, encoder, save_name, device='cpu', centers=None):
    if path:
        dec = DFC(cluster_number=args.cluster_number,
                  hidden_dimension=args.dfc_hidden_dim).to(device)
        dec.load_state_dict(torch.load(path, map_location=device))
    else:
        dec = train_dec(args, dataloader_list, encoder, device,
                        centers=centers,  save_name=save_name)
    return dec


def get_dfc(args, path, dataloader_list, encoder, save_name, encoder_group_0=None, encoder_group_1=None, dfc_group_0=None, dfc_group_1=None, device='cpu', centers=None, get_loss_trade_off=lambda step: (10, 10, 10)):
    if path:
        dfc = DFC(cluster_number=args.cluster_number,
                  hidden_dimension=args.dfc_hidden_dim).to(device)
        dfc.load_state_dict(torch.load(path, map_location=device))
    else:
        dfc = train_dfc(args, dataloader_list, encoder, encoder_group_0, encoder_group_1, dfc_group_0, dfc_group_1,
                        device, centers=centers, get_loss_trade_off=get_loss_trade_off, save_name=save_name)
    return dfc


## Main Code

In [45]:
set_seed(args.seed)
os.makedirs(args.log_dir, exist_ok=True)
#Set wandb loging offline, avoid the need for an account.
os.environ["WANDB_MODE"] = "dryrun"
wandb.init(project="offline-run")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    torch.cuda.set_device(args.gpu)
print(f"Using {device}")

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Offline run mode, not syncing to the cloud.
[34m[1mwandb[0m: W&B syncing is set to `offline` in this directory.  Run `wandb online` to enable cloud syncing.
Using cuda


In [44]:
dataloader_0, dataloader_1 = get_dataset[args.dataset](args)
print("Loading Encoder")
encoder = get_encoder(args, "encoder", args.encoder_legacy_path, args.encoder_path, [
                        dataloader_0, dataloader_1], device=device, encoder_type=args.encoder_type)

Global seed set to 2019
Loading Encoder
Loading the variational autoencoder


ValueError: You must call `wandb.init` before calling watch

## DFC Run
In this section we have the cells with all the steps to train a DFC.   
We first load the encoders used for the two dec, if no path is selected then we train new ones.
Next we load or train with K-means the cluster centers 

In [None]:
if args.method == 'dfc':
        print("Start pretraining individual golden standard DECs")
        print("loading the golden standard group 0 encoder")
        encoder_group_0 = get_encoder(args, "encoder_0", args.encoder_0_legacy_path, args.encoder_0_path, [
                                      dataloader_0], device=device, encoder_type=args.encoder_type)
        
        print("loading the golden standard group 1 encoder")
        encoder_group_1 = get_encoder(args, "encoder_1", args.encoder_1_legacy_path, args.encoder_1_path, [
                                      dataloader_1], device=device, encoder_type=args.encoder_type)
       
        cluster_centers_0 = None
        cluster_centers_1 = None
        if not args.dfc_0_path:
            # We don't have pretrained decs for both groups -> we have to generate cluster centers
            print("Load group 0 initial cluster definitions")
            cluster_centers_0 = get_cluster_centers(args, encoder_group_0, args.cluster_number, [
                                                    dataloader_0], args.cluster_0_path, device=device, save_name="clusters_0")

            print("Load group 1 initial cluster definitions")
            cluster_centers_1 = get_cluster_centers(args, encoder_group_1, args.cluster_number, [dataloader_1],
                                                    args.cluster_1_path, device=device, save_name="clusters_1")


In [None]:
if args.method == 'dfc':
        print("Train golden standard group 0 DEC")
        # note that the weight of the fairness losses are set to 0, making this a DEC instead of a DFC
        dfc_group_0 = get_dec(args, args.dfc_0_path, [
                              dataloader_0], encoder_group_0, "DEC_G0", device=device, centers=cluster_centers_0)

        print("Train golden standard group 1 DEC")
        # note that the weight of the fairness losses are set to 0, making this a DEC instead of a DFC
        dfc_group_1 = get_dec(args, args.dfc_1_path, [
                              dataloader_1], encoder_group_1, "DEC_G1", device=device, centers=cluster_centers_1)

        print("Load cluster centers for final DFC")
        cluster_centers = get_cluster_centers(args, encoder, args.cluster_number, [dataloader_0, dataloader_1],
                                              args.cluster_path, device=device, save_name="clusters_dfc")

        print("Train final DFC")

        loss_tradeoff = lambda _: (1, 1, 1)
        if args.dfc_tradeoff == 'no_fair':
            loss_tradeoff = lambda _: (0, 1, 1)
        elif args.dfc_tradeoff == 'no_struct':
            loss_tradeoff = lambda _: (1, 0, 1)

        dfc = get_dfc(args, args.dfc_path, [dataloader_0, dataloader_1], encoder, "DFC", encoder_group_0=encoder_group_0,
                      encoder_group_1=encoder_group_1, dfc_group_0=dfc_group_0, dfc_group_1=dfc_group_1, device=device,
                      centers=cluster_centers, get_loss_trade_off=loss_tradeoff)

In [None]:
if args.method == 'dec':
        print("Load cluster centers for final DEC")
        cluster_centers = get_cluster_centers(args, encoder, args.cluster_number, [dataloader_0, dataloader_1],
                                              args.cluster_path, device=device, save_name="clusters_dec")

        print("Train final DEC")
        dec = get_dec(args, None, [dataloader_0, dataloader_1],
                      encoder, "DEC", device=device, centers=cluster_centers)
   