In [1]:
import argparse
import os, sys
import os.path as osp
import torchvision
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import network, loss
from torch.utils.data import DataLoader
from data_list import ImageList, ImageList_idx
import random, pdb, math, copy
from tqdm import tqdm
from scipy.spatial.distance import cdist
from sklearn.metrics import confusion_matrix
import torch.nn.functional as F
import random

In [2]:
parser = argparse.ArgumentParser(description='Ours')
parser.add_argument('--gpu_id',
                    type=str,
                    nargs='?',
                    default='8',
                    help="device id to run")
parser.add_argument('--s', type=int, default=0, help="source")
parser.add_argument('--t', type=int, default=1, help="target")
parser.add_argument('--max_epoch',
                    type=int,
                    default=15,
                    help="max iterations")
parser.add_argument('--interval', type=int, default=15)
parser.add_argument('--batch_size',
                    type=int,
                    default=64,
                    help="batch_size")
parser.add_argument('--worker',
                    type=int,
                    default=4,
                    help="number of workers")
parser.add_argument(
    '--dset',
    type=str,
    default='visda-2017')
parser.add_argument('--lr', type=float, default=1e-3, help="learning rate")
parser.add_argument('--net',
                    type=str,
                    default='resnet101')
parser.add_argument('--seed', type=int, default=2020, help="random seed")

parser.add_argument('--bottleneck', type=int, default=256)
parser.add_argument('--epsilon', type=float, default=1e-5)
parser.add_argument('--layer',
                    type=str,
                    default="wn",
                    choices=["linear", "wn"])
parser.add_argument('--classifier',
                    type=str,
                    default="bn",
                    choices=["ori", "bn"])
parser.add_argument('--output', type=str, default='hat/target/')
parser.add_argument('--output_src', type=str, default='hat/source/')
parser.add_argument('--da',
                    type=str,
                    default='uda')
parser.add_argument('--issave', type=bool, default=True)
args = parser.parse_args(args=[])

In [3]:
if args.dset == 'visda-2017':
    names = ['train', 'validation']
    args.class_num = 12

os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
'''SEED = args.seed
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.backends.cudnn.deterministic = True
'''
for i in range(len(names)):
    if i == args.s:
        continue
    args.t = i

    folder = './data/'
    args.s_dset_path = folder + args.dset + '/' + names[
        args.s] + '_list.txt'
    args.t_dset_path = folder + args.dset + '/' + names[
        args.t] + '_list.txt'
    args.test_dset_path = folder + args.dset + '/' + names[
        args.t] + '_list.txt'

    args.output_dir_src = osp.join(args.output_src, args.da, args.dset,
                                    names[args.s][0].upper())
    args.output_dir = osp.join(
        args.output, args.da, args.dset,
        names[args.s][0].upper() + names[args.t][0].upper())
    args.name = names[args.s][0].upper() + names[args.t][0].upper()

    if not osp.exists(args.output_dir):
        os.system('mkdir -p ' + args.output_dir)
    if not osp.exists(args.output_dir):
        os.mkdir(args.output_dir)


In [30]:
def image_train(resize_size=256, crop_size=224, alexnet=False):
    if not alexnet:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
    else:
        normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ])


def image_test(resize_size=256, crop_size=224, alexnet=False):
    if not alexnet:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
    else:
        normalize = Normalize(meanfile='./ilsvrc_2012_mean.npy')
    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(), normalize
    ])


def data_load(args):
    ## prepare data
    dsets = {}
    dset_loaders = {}
    train_bs = args.batch_size
    txt_src = open(args.s_dset_path).readlines()
    txt_tar = open(args.t_dset_path).readlines()
    txt_test = open(args.test_dset_path).readlines()

    if not args.da == 'uda':
        label_map_s = {}
        for i in range(len(args.src_classes)):
            label_map_s[args.src_classes[i]] = i

        new_tar = []
        for i in range(len(txt_tar)):
            rec = txt_tar[i]
            reci = rec.strip().split(' ')
            if int(reci[1]) in args.tar_classes:
                if int(reci[1]) in args.src_classes:
                    line = reci[0] + ' ' + str(label_map_s[int(
                        reci[1])]) + '\n'
                    new_tar.append(line)
                else:
                    line = reci[0] + ' ' + str(len(label_map_s)) + '\n'
                    new_tar.append(line)
        txt_tar = new_tar.copy()
        txt_test = txt_tar.copy()

    dsize = len(txt_src)
    tr_size = int(0.9*dsize)
    # print(dsize, tr_size, dsize - tr_size)
    tr_txt, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size])

    dsets["source_tr"] = ImageList(tr_txt, transform=image_train())
    dset_loaders["source_tr"] = DataLoader(dsets["source_tr"],
                                           batch_size=64,
                                           shuffle=True,
                                           num_workers=args.worker,
                                           drop_last=False)
    dsets["source_te"] = ImageList(te_txt, transform=image_test())
    dset_loaders["source_te"] = DataLoader(dsets["source_te"],
                                           batch_size=1,
                                           shuffle=False,
                                           num_workers=args.worker,
                                           drop_last=False)
    dsets["target"] = ImageList_idx(txt_tar, transform=image_train())
    dset_loaders["target"] = DataLoader(dsets["target"],
                                        batch_size=64,
                                        shuffle=True,
                                        num_workers=args.worker,
                                        drop_last=False)
    dsets["test"] = ImageList_idx(txt_test, transform=image_test())
    dset_loaders["test"] = DataLoader(dsets["test"],
                                      batch_size=1,
                                      shuffle=False,
                                      num_workers=args.worker,
                                      drop_last=False)

    return dset_loaders



In [5]:
args.output_dir_src

'hat/source/uda/visda-2017/T'

In [32]:
dset_loaders = data_load(args)
## set base network
netF = network.ResBase(res_name=args.net).cuda()
netB = network.feat_bootleneck_sdaE(type=args.classifier,
                                feature_dim=netF.in_features,
                                bottleneck_dim=args.bottleneck).cuda()
netC = network.feat_classifier(type=args.layer,
                                class_num=args.class_num,
                                bottleneck_dim=args.bottleneck).cuda()


modelpath = 'hat/target/uda/visda-2017/TV/' + '/target_F_final.pt'
netF.load_state_dict(torch.load(modelpath))
modelpath = 'hat/target/uda/visda-2017/TV/' + '/target_B_final.pt'
netB.load_state_dict(torch.load(modelpath))
modelpath = 'hat/target/uda/visda-2017/TV/' + '/target_C_final.pt'
netC.load_state_dict(torch.load(modelpath))
netF.eval()
netB.eval()
netC.eval()

feat_classifier(
  (fc): Linear(in_features=256, out_features=12, bias=True)
)

# Domain Classifier

In [17]:
class CLS_D(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer=nn.Sequential(nn.Linear(256,64),nn.ReLU(),nn.Linear(64,2))
    def forward(self,x):
        return self.layer(x)

In [18]:
d_cls=CLS_D().cuda()


In [19]:
# random sampling data for training domain classifier
loader_t = dset_loaders['target']
iter_test_t = iter(loader_t)
data_t = iter_test_t.next()
#data_t = iter_test_t.next()
data_t = iter_test_t.next()
input_t = data_t[0].cuda()
label_t = data_t[1].cuda()

loader_s = dset_loaders['source_tr']
iter_test_s = iter(loader_s)
#data_s = iter_test_s.next()
data_s = iter_test_s.next()
data_s = iter_test_s.next()
data_s = iter_test_s.next()
input_s = data_s[0].cuda()
label_s = data_s[1].cuda()

In [20]:
with torch.no_grad():
    torch.cuda.empty_cache()

    input_t=netB.bn(netB.bottleneck(netF(input_t)))
    label_t=torch.ones(input_t.shape[0]).long().cuda()

    torch.cuda.empty_cache()

    input_s=netB.bn(netB.bottleneck(netF(input_s)))
    label_s=torch.zeros(input_s.shape[0]).long().cuda()
        

In [21]:
inputs=torch.cat((input_t,input_s))
labels=torch.cat((label_t,label_s))
inputs_np=inputs.cpu().numpy()
labels_np=labels.cpu().numpy()
state = np.random.get_state()
np.random.shuffle(inputs_np)
np.random.set_state(state)
np.random.shuffle(labels_np)

In [22]:
inputs=torch.from_numpy(inputs_np).cuda()
labels=torch.from_numpy(labels_np).cuda().long()

In [24]:
optim=torch.optim.SGD(d_cls.parameters(),lr=0.01)
for i in range(200):
    output=d_cls(inputs)
    loss=nn.CrossEntropyLoss()(output,labels)
    optim.zero_grad()
    loss.backward()
    optim.step()

    output=d_cls(inputs)
    _,pred=torch.max(output,1)
    accuracy = torch.sum(
        torch.squeeze(pred).float() == labels).item() / float(
            labels.size()[0])
    print(loss,accuracy)

tensor(0.6949, device='cuda:0', grad_fn=<NllLossBackward>) 0.5546875
tensor(0.6809, device='cuda:0', grad_fn=<NllLossBackward>) 0.5703125
tensor(0.6677, device='cuda:0', grad_fn=<NllLossBackward>) 0.6015625
tensor(0.6552, device='cuda:0', grad_fn=<NllLossBackward>) 0.625
tensor(0.6434, device='cuda:0', grad_fn=<NllLossBackward>) 0.671875
tensor(0.6323, device='cuda:0', grad_fn=<NllLossBackward>) 0.7578125
tensor(0.6217, device='cuda:0', grad_fn=<NllLossBackward>) 0.765625
tensor(0.6115, device='cuda:0', grad_fn=<NllLossBackward>) 0.84375
tensor(0.6016, device='cuda:0', grad_fn=<NllLossBackward>) 0.8828125
tensor(0.5921, device='cuda:0', grad_fn=<NllLossBackward>) 0.8984375
tensor(0.5829, device='cuda:0', grad_fn=<NllLossBackward>) 0.890625
tensor(0.5739, device='cuda:0', grad_fn=<NllLossBackward>) 0.8828125
tensor(0.5653, device='cuda:0', grad_fn=<NllLossBackward>) 0.8828125
tensor(0.5568, device='cuda:0', grad_fn=<NllLossBackward>) 0.8828125
tensor(0.5487, device='cuda:0', grad_fn=<Nl

## accuracy of domain classifier

In [None]:
start_test = True
loader=dset_loaders['source_te']
with torch.no_grad():
    iter_test = iter(loader)
    for i in range(len(loader)):
        data = iter_test.next()
        inputs = data[0]
        labels = data[1]
        inputs = inputs.cuda()
        outputs = netB.bn(netB.bottleneck(netF(inputs)))
        outputs=d_cls(outputs)
        if start_test:
            all_output = outputs.float().cpu()
            all_label = labels.float()
            start_test = False
        else:
            all_output = torch.cat((all_output, outputs.float().cpu()), 0)
            all_label = torch.cat((all_label, labels.float()), 0)
all_label=torch.zeros_like(all_label).long()
_, predict = torch.max(all_output, 1)
accuracy = torch.sum(
    torch.squeeze(predict).float() == all_label).item() / float(
        all_label.size()[0])
print(accuracy)

In [None]:
start_test = True
loader=dset_loaders['test']
with torch.no_grad():
    iter_test = iter(loader)
    for i in range(len(loader)):
        data = iter_test.next()
        inputs = data[0]
        labels = data[1]
        inputs = inputs.cuda()
        outputs = netB.bn(netB.bottleneck(netF(inputs)))
        outputs=d_cls(outputs)
        if start_test:
            all_output = outputs.float().cpu()
            all_label = labels.float()
            start_test = False
        else:
            all_output = torch.cat((all_output, outputs.float().cpu()), 0)
            all_label = torch.cat((all_label, labels.float()), 0)
all_label=torch.ones_like(all_label).long()
_, predict = torch.max(all_output, 1)
accuracy = torch.sum(
    torch.squeeze(predict).float() == all_label).item() / float(
        all_label.size()[0])
print(accuracy)

## Accuracy on source and target domain with the estimated domain ID

In [None]:
start_test = True
loader=dset_loaders['source_te']
with torch.no_grad():
    iter_test = iter(loader)
    for i in range(len(loader)):
        data = iter_test.next()
        inputs = data[0]
        labels = data[1]
        inputs = inputs.cuda()
        outputs = netB.bn(netB.bottleneck(netF(inputs)))
        outputs=d_cls(outputs)
        idx=torch.max(outputs,1)[-1].long().item()
        output_f, masks = netB(netF(inputs),t=idx)
        output = netC(output_f)
        softmax_out = nn.Softmax(dim=1)(output)
        if start_test:
            all_output = output.float().cpu()
            all_label = labels.float()
            start_test = False
        else:
            all_output = torch.cat((all_output, output.float().cpu()), 0)
            all_label = torch.cat((all_label, labels.float()), 0)
_, predict = torch.max(all_output, 1)
matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
acc = matrix.diagonal() / matrix.sum(axis=1) * 100
aacc = acc.mean()
aa = [str(np.round(i, 2)) for i in acc]
acc = ' '.join(aa)
print(aacc, acc)

In [None]:
start_test = True
loader=dset_loaders['test']
with torch.no_grad():
    iter_test = iter(loader)
    for i in range(len(loader)):
        data = iter_test.next()
        inputs = data[0]
        labels = data[1]
        inputs = inputs.cuda()
        outputs = netB.bn(netB.bottleneck(netF(inputs)))
        outputs=d_cls(outputs)
        idx=torch.max(outputs,1)[-1].long().item()
        output_f, masks = netB(netF(inputs),t=idx)
        output = netC(output_f)
        softmax_out = nn.Softmax(dim=1)(output)
        if start_test:
            all_output = output.float().cpu()
            all_label = labels.float()
            start_test = False
        else:
            all_output = torch.cat((all_output, output.float().cpu()), 0)
            all_label = torch.cat((all_label, labels.float()), 0)
_, predict = torch.max(all_output, 1)
matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
acc = matrix.diagonal() / matrix.sum(axis=1) * 100
aacc = acc.mean()
aa = [str(np.round(i, 2)) for i in acc]
acc = ' '.join(aa)
print(aacc, acc)