In [1]:
import warnings


def warn(*_ignored_args, **_ignored_kwargs):
    """Swallow every warning: a drop-in, do-nothing stand-in for ``warnings.warn``.

    Installed below by rebinding ``warnings.warn`` itself, so calls that look the
    function up through the ``warnings`` module at call time become no-ops.
    """
    return None


warnings.warn = warn

import torch
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from utils import datasets, utils
import os
import time

from methods.digits import ours_beta

In [2]:
# Load all digits domains using the high-imbalance preset; the per-domain train
# class counts printed by the next cell show the resulting imbalance.
imbalance_amounts = datasets.HIGH_IMBALANCE_AMOUNTS_DIGITS_DIFFERENT_PY
x_train, x_test, y_train, y_test, classes = datasets.load('digits', imbalance=imbalance_amounts)

# Domain names are the keys of the per-domain training data.
domains = x_train.keys()

In [3]:
# For each domain, print the shape and the per-class label counts of the train
# and test splits, followed by a blank separator line.
for domain in domains:
    for split_name, x_split, y_split in (('train', x_train, y_train),
                                         ('test', x_test, y_test)):
        print(domain, split_name, x_split[domain].shape)
        print(np.unique(y_split[domain], return_counts=True))
    print()

mnist train (2370, 3, 28, 28)
(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([500, 500, 500, 500, 200, 100,  40,  10,  10,  10]))
mnist test (10000, 3, 28, 28)
(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([ 980, 1135, 1032, 1010,  982,  892,  958, 1028,  974, 1009]))

mnistm train (2370, 3, 28, 28)
(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([ 10, 200, 100, 500,  40, 500,  10, 500, 500,  10]))
mnistm test (10000, 3, 28, 28)
(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([ 980, 1135, 1032, 1010,  982,  892,  958, 1028,  974, 1009]))

svhn train (2370, 3, 32, 32)
(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([200,  10,  10, 500, 500, 100,  40,  10, 500, 500]))
svhn test (26032, 3, 32, 32)
(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([1744, 5099, 4149, 2882, 2523, 2384, 1977, 2019, 1660, 1595]))

syn train (2370, 3, 32, 32)
(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([500, 100, 500,  10,  10,  10, 500, 500,  40, 200]))
syn test (9553, 3, 32, 32)
(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),

## Parameters

In [4]:
# Target domain; every other entry of `domains` is used as a source domain below.
domain_t = 'mnist'

lr = 1e-4
batch_size = 128
epochs = 150
# Positional arguments for ours_beta.Method.
# NOTE(review): the literal 3 is presumably the number of input channels (the
# printed train shapes have 3 at axis 1) — confirm against ours_beta.Method.
m_params = (epochs, batch_size, lr, 3, len(classes))
evaluation_metrics = ['acc', 'bal_acc', 'auc', 'f1']

# Seed the global numpy and torch RNGs so repeated runs of the notebook start from
# the same random state (torch.manual_seed also seeds CUDA devices). Some GPU
# kernels may still be nondeterministic.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

## Select source domains

In [5]:
# Use every domain except the target as a labelled source domain.
# Fail fast on a mistyped `domain_t`: without this check every domain would become
# a source, and the mistake would only surface later as a KeyError on
# `x_train[domain_t]` in the method-execution cell.
assert domain_t in domains, f"unknown target domain {domain_t!r}; expected one of {list(domains)}"

x_sources = {name: x_train[name] for name in domains if name != domain_t}
y_sources = {name: y_train[name] for name in domains if name != domain_t}

assert x_sources, "need at least one source domain besides the target"

## Method execution

In [6]:
# Build the method from the hyper-parameter tuple defined in the Parameters cell.
method = ours_beta.Method(*m_params)

# Fit on the source domains together with the target domain's training split.
x_target_train, y_target_train = x_train[domain_t], y_train[domain_t]
method.train(x_sources, y_sources, x_target_train, y_target_train, domain_t=domain_t)

# Predict on the target test split. `y_probas` is presumably per-class
# probabilities (name + AUC metric suggest so) — TODO confirm in ours_beta.
y_probas = method.predict(x_test[domain_t])

# Score the predictions against the target test labels with the configured metrics.
met = utils.eval_results(y_probas, y_test[domain_t], evaluation_metrics)

Epoch 1/150 - Time elapsed 00:00:04
Epoch 10/150 - Time elapsed 00:00:28
Epoch 20/150 - Time elapsed 00:00:54
Epoch 30/150 - Time elapsed 00:01:18
Epoch 40/150 - Time elapsed 00:01:44
Epoch 50/150 - Time elapsed 00:02:08
Epoch 60/150 - Time elapsed 00:02:31
Epoch 70/150 - Time elapsed 00:02:56
Epoch 80/150 - Time elapsed 00:03:22
Epoch 90/150 - Time elapsed 00:03:45
Epoch 100/150 - Time elapsed 00:04:10
Epoch 110/150 - Time elapsed 00:04:35
Epoch 120/150 - Time elapsed 00:05:01
Epoch 130/150 - Time elapsed 00:05:28
Epoch 140/150 - Time elapsed 00:05:54
Epoch 150/150 - Time elapsed 00:06:18


In [7]:
# Display the metric dict computed by utils.eval_results in the previous cell.
met

{'acc': 94.1,
 'bal_acc': 94.0846800722144,
 'auc': 99.8618771859756,
 'f1': 93.95243334841766}