In [None]:
import numpy as np
import torch 
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt
import importlib
from IPython.display import display
import torch
from models.hdisc_msda import Disc_MSDANet, weighted_mse
from models.MDAN import MDANet_general
from utils.load_amazon import load_amazon
from utils.utils_hdisc import batch_loader, split_source_target

In [None]:
# Load the Amazon TF-IDF data: per-domain inputs, per-domain labels and the domain names
X_amazon, y_amazon, domain_list = load_amazon(filepath='./data/amazon_tf-idf.npz', domains=None)

# Standardize the labels with the mean / std pooled over all domains
y_pooled = np.concatenate(y_amazon)
mu_y, std_y = np.mean(y_pooled), np.std(y_pooled)
print(std_y)
y_amazon = [(labels - mu_y) / std_y for labels in y_amazon]

# Number of domains
n_domains = len(X_amazon)

# Fix the NumPy and PyTorch seeds for reproducible results
np.random.seed(0)
torch.manual_seed(0)

def get_feature_extractor(input_dim=X_amazon[0].shape[1]):
    """Build the layers of the feature extractor.

    Args:
        input_dim: Number of input features per sample. Defaults to the
            second dimension of the first loaded domain's input matrix.

    Returns:
        nn.ModuleList with two bias-free stages, each Linear -> ELU ->
        Dropout(p=0.1), mapping input_dim -> 500 -> 20.
    """
    # The first layer must follow input_dim; a fixed width here would silently
    # ignore the argument and break on data with a different feature count.
    return nn.ModuleList([
            nn.Linear(int(input_dim), 500, bias=False), nn.ELU(), nn.Dropout(p=0.1),
            nn.Linear(500, 20, bias=False), nn.ELU(), nn.Dropout(p=0.1)])

def get_predictor(output_dim=1):
    """Build the layers of the prediction head.

    Args:
        output_dim: Number of outputs of the linear layer.

    Returns:
        nn.ModuleList holding a single bias-free linear layer that maps
        20 input features (the width produced by get_feature_extractor)
        to output_dim values.
    """
    head = nn.Linear(in_features=20, out_features=output_dim, bias=False)
    return nn.ModuleList([head])

def get_discriminator(output_dim=1):
    """Build the layers of the discrepancy head (passed to the model as h_disc).

    Args:
        output_dim: Number of outputs of the linear layer.

    Returns:
        nn.ModuleList whose only entry is a bias-free linear layer from
        20 input features to output_dim values.
    """
    layers = nn.ModuleList()
    layers.append(nn.Linear(20, output_dim, bias=False))
    return layers
    
    
def save_result(result, filepath, domain_list):
    """Write per-run, per-domain scores to a CSV file.

    Args:
        result: Sequence of dicts, one per run; each dict maps every name in
            ``domain_list`` to that run's score.
        filepath: Destination CSV path. Missing parent directories are
            created so a long run is not lost to a missing output folder.
        domain_list: Domain names; they become the CSV columns, in this order.

    Raises:
        KeyError: If a run in ``result`` has no entry for one of the domains.
    """
    import os  # local import: the function stays self-contained

    # One column per domain, one row per run (these are raw scores, not medians)
    scores_by_domain = {k: [run[k] for run in result] for k in domain_list}

    parent_dir = os.path.dirname(filepath)
    if parent_dir:  # a bare file name has no directory to create
        os.makedirs(parent_dir, exist_ok=True)

    pd.DataFrame(scores_by_domain).to_csv(filepath)



In [None]:
import importlib
import models.hdisc_msda 
importlib.reload(models.hdisc_msda)
from models.hdisc_msda import Disc_MSDANet

# Number of times the full sweep over all target domains is repeated
nb_experiments = 5
# One dict per repetition, mapping domain name -> score on that target domain
results_mse, results_mae = [], []

# Settings handed to Disc_MSDANet; the network parts ('feature_extractor',
# 'h_pred', 'h_disc') are added per target domain inside the experiment loop.
# NOTE(review): infinite min_pred / max_pred presumably disable prediction
# clipping -- confirm in Disc_MSDANet.
params= {'input_dim': X_amazon[0].shape[1], 'output_dim': 1, 'n_sources': n_domains-1, 'loss': torch.nn.MSELoss(),
         'weighted_loss': weighted_mse, 'min_pred': -np.inf, 'max_pred': np.inf}

# Number of epochs for the pre-training and alternated-training phases
epochs_pretrain, epochs_adapt = 0, 100
# Number of update steps per batch for each component in the alternated phase
epochs_h_disc, epochs_feat, epochs_alpha, epochs_pred = 1, 1, 1, 1

# Use the first GPU when one is available, otherwise run on CPU so the
# notebook still executes on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
lr = 0.001
batch_size = 128
# Repeat the sweep nb_experiments times. Within one repetition each domain is
# used once as the target (index i is passed to split_source_target) and the
# model is evaluated on that target's data.
for exp in range(nb_experiments):
    print('\n ----------------------------- %i / %i -----------------------------'%(exp+1, nb_experiments))
    mse_list, mae_list =  {}, {}
    # NOTE(review): nothing in this cell fills `alphas`. The cell used to end by
    # writing it through an undefined handle `w`, which raised NameError after
    # all training had finished; that dangling code is gone. The name is kept
    # in case a later cell refers to it.
    alphas = {}
    for i, domain in enumerate(domain_list):
        #Split source and target
        X_s, X_t, y_s, y_t = split_source_target(X_amazon, y_amazon, i, device, merge=False)
        #Initialize model (fresh networks and optimizers for every target domain)
        params['feature_extractor'] = get_feature_extractor()
        params['h_pred'] = get_predictor(output_dim=1)
        params['h_disc'] = get_discriminator(output_dim=1)
        model = Disc_MSDANet(params).to(device)
        opt_feat = torch.optim.Adam([{'params': model.feature_extractor.parameters()}],lr=lr)
        opt_pred = torch.optim.Adam([{'params': model.h_pred.parameters()}],lr=lr)
        opt_disc = torch.optim.Adam([{'params': model.h_disc.parameters()}],lr=lr)
        opt_alpha = torch.optim.Adam([{'params': model.alpha}],lr=lr)
        model.optimizers(opt_feat, opt_pred, opt_disc, opt_alpha)
        print('----', domain, '----')
        #Pre-training (skipped entirely when epochs_pretrain is 0)
        print('------------Pre-training------------')
        for epoch in range(epochs_pretrain):
            loader = batch_loader(X_s, y_s ,batch_size = batch_size)
            for x_bs, y_bs in loader:
                model.train_prediction(x_bs, X_t, y_bs, clip=1, pred_only=False)
            #Logs (every pre-training epoch)
            source_loss, disc = model.compute_loss(X_s, X_t, y_s)
            reg_loss = model.loss(y_t, model.predict(X_t))
            print('Epoch: %i/%i ; Train loss: %.3f ; Disc: %.3f ; Test loss: %.3f'%(epoch+1, epochs_pretrain, source_loss.item(), disc.item(), reg_loss.item()))

        #Alternated training
        print('------------Alternated training------------')
        for epoch in range(epochs_adapt):
            # Back to training mode: the logging branch below switches to eval mode
            model.train()
            loader = batch_loader(X_s, y_s ,batch_size = batch_size)
            for x_bs, y_bs in loader:
                # Random target batch (drawn with replacement) for this source batch
                ridx = np.random.choice(X_t.shape[0], batch_size)
                x_bt = X_t[ridx,:]
                #Train h to minimize source loss
                for _ in range(epochs_pred):
                    model.train_prediction(x_bs, x_bt, y_bs, pred_only=False)

                #Train h' to maximize discrepancy
                for _ in range(epochs_h_disc):
                    model.train_h_discrepancy(x_bs, x_bt, y_bs)

                #Train phi to minimize discrepancy
                for _ in range(epochs_feat):
                    model.train_feat_discrepancy(x_bs, x_bt, y_bs, mu=0)

                #Train alpha to minimize discrepancy
                for _ in range(epochs_alpha):
                    model.train_alpha_discrepancy(x_bs, x_bt, y_bs, clip=1, lam_alpha=0.1)

            #Logs (every 100 epochs)
            if (epoch+1)%100==0:
                model.eval()
                source_loss, disc = model.compute_loss(X_s, X_t, y_s)
                reg_loss = model.loss(y_t, model.predict(X_t))
                print('Epoch: %i/%i (h_pred); Train loss: %.3f ; Disc: %.3f ; Test loss: %.3f'%(epoch+1, epochs_adapt, source_loss.item(), disc.item(), reg_loss.item()))

        #Final evaluation on the target domain.
        # Switch to eval mode explicitly: otherwise the mode depends on whether
        # the last epoch happened to hit the logging branch above.
        model.eval()
        with torch.no_grad():
            y_pred = model.predict(X_t)
            mse_list[domain] = model.loss(y_t, y_pred).item()
            # Non in-place squeeze so y_t keeps its shape; MAE = mean absolute error
            abs_err = torch.abs(y_t.squeeze() - y_pred.squeeze())
            mae_list[domain] = torch.sum(abs_err).item()/y_t.shape[0]
        print(mae_list[domain])
    results_mse.append(mse_list)
    results_mae.append(mae_list)
    # Rewrite the CSVs after every repetition so partial results survive an interruption
    save_result(results_mse, './results/ADisc_MSDA_mse.csv', domain_list)
    save_result(results_mae, './results/ADisc_MSDA_mae.csv', domain_list)
# Final write of the complete results
save_result(results_mse, './results/ADisc_MSDA_mse.csv', domain_list)
save_result(results_mae, './results/ADisc_MSDA_mae.csv', domain_list)