In [None]:
import argparse
import os
import os.path as osp

import numpy as np
import torch
# Report up front whether PyTorch can see a CUDA device: the script below
# moves the networks and every batch to the GPU with .cuda() unconditionally.
print(torch.cuda.is_available())

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import network
import loss
import pre_process as prep
import lr_schedule
from pre_process import ImageList, image_classification_test
import copy
import random



if __name__ == "__main__":
    # Initialization-stage training script ("Code for RSDA-MSTN").
    #
    # Trains network.ResNet50 together with network.AdversarialNetwork on a
    # source/target pair read from the Office list files.  The total loss is
    #     cross-entropy(source) + calc_coeff(i) * loss.SM(...) + loss.DANN(...)
    # Target ground-truth labels are never used for training; the argmax of
    # the target logits serves as pseudo labels.  The target list is also
    # used as the test set, and the best-scoring model is written to
    # save/init_model/<source>_<target>.pkl every `snapshot_interval` steps.
    parser = argparse.ArgumentParser(description='Code for RSDA-MSTN')
    parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
    parser.add_argument('--source', type=str, default='amazon', choices=["amazon", "dslr", "webcam"])
    parser.add_argument('--target', type=str, default='dslr', choices=["amazon", "dslr", "webcam"])
    parser.add_argument('--test_interval', type=int, default=50, help="interval of two continuous test phase")
    parser.add_argument('--snapshot_interval', type=int, default=1000, help="interval of two continuous output model")
    parser.add_argument('--lr', type=float, default=0.001, help="learning rate")
    parser.add_argument('--stages', type=int, default=6, help="training stages")
    # An explicit empty argv means only the defaults above are ever used
    # (presumably so the cell runs inside a notebook kernel whose own argv
    # would confuse argparse).  Use parse_args() to honour the command line.
    args = parser.parse_args([])
    s_dset_path = 'data/office/' + args.source + '_list.txt'
    t_dset_path = 'data/office/' + args.target + '_list.txt'

    # NOTE(review): this only takes effect if CUDA has not been initialised
    # yet in this process; an earlier CUDA query may already have fixed the
    # visible device set -- confirm on the target PyTorch version.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id

    config = {}
    config["source"] = args.source
    config["target"] = args.target
    config["gpu"] = args.gpu_id
    config["test_interval"] = args.test_interval
    config["snapshot_interval"] = args.snapshot_interval
    config["output_for_test"] = True
    config["output_path"] = "snapshot/init"
    if not osp.exists(config["output_path"]):
        os.makedirs(config["output_path"])
    # Log file; closed in the `finally` clause guarding the training loop.
    config["out_file"] = open(osp.join(config["output_path"], args.source + "_" + args.target + "_log.txt"), "w")

    config["prep"] = {'params': {"resize_size": 256, "crop_size": 224}}
    config["network"] = {"name": network.ResNet50,
                         "params": {"new_cls": True, "feature_dim": 256, "class_num": 31}}
    config["optimizer"] = {"type": optim.SGD,
                           "optim_params": {'lr': args.lr, "momentum": 0.9,
                                            "weight_decay": 0.0005, "nesterov": True},
                           "lr_type": "inv",
                           "lr_param": {"lr": args.lr, "gamma": 0.001, "power": 0.75}}
    config["data"] = {"source": {"list_path": s_dset_path, "batch_size": 36},
                      "target": {"list_path": t_dset_path, "batch_size": 36},
                      "test": {"list_path": t_dset_path, "batch_size": 72}}

    # Task-specific schedule: amazon->dslr and amazon->webcam run for 2000
    # iterations with seed 0, every other pair for 4000 iterations with seed 1.
    if config["source"] == "amazon" and config["target"] in ("dslr", "webcam"):
        config["iterations"] = 2000
        seed = 0
    else:
        config["iterations"] = 4000
        seed = 1

    # Seed every RNG in use so that runs are repeatable.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

    config["out_file"].write('\n--- initialization ---\n')
    source = config["source"]
    target = config["target"]
    prep_dict = {}
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])
    prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]

    # Read each list file inside a context manager so the handle is closed.
    list_lines = {}
    for split in ("source", "target", "test"):
        with open(data_config[split]["list_path"]) as list_file:
            list_lines[split] = list_file.readlines()

    dsets["source"] = ImageList(list_lines["source"], transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs,
                                        shuffle=True, num_workers=0, drop_last=True)
    # The target loader deliberately reuses the source batch size (train_bs),
    # exactly as before; the "target" batch_size entry of the config is unused.
    dsets["target"] = ImageList(list_lines["target"], transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs,
                                        shuffle=True, num_workers=0, drop_last=True)

    dsets["test"] = ImageList(list_lines["test"], transform=prep_dict["test"])
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs,
                                      shuffle=False, num_workers=0)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network = base_network.cuda()

    ## add additional network for some methods
    ad_net = network.AdversarialNetwork(base_network.output_num(), 1024)
    ad_net = ad_net.cuda()

    # Collect the parameter groups BEFORE the optional DataParallel wrap:
    # nn.DataParallel does not forward custom methods such as
    # get_parameters(), so calling them on the wrapper raises AttributeError.
    base_parameters = base_network.get_parameters()
    parameter_classifier = [base_parameters[1]]
    parameter_feature = base_parameters[0:1] + ad_net.get_parameters()

    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net)
        base_network = nn.DataParallel(base_network)

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer_classfier = optimizer_config["type"](parameter_classifier,
                                                   **(optimizer_config["optim_params"]))
    optimizer_feature = optimizer_config["type"](parameter_feature,
                                                 **(optimizer_config["optim_params"]))
    # Initial learning rates of all groups (not read again within this script;
    # kept so the name stays available to anything that inspects it later).
    param_lr = [param_group["lr"] for param_group in optimizer_feature.param_groups]
    param_lr.append(optimizer_classfier.param_groups[0]["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    if len_train_source == 0 or len_train_target == 0:
        # drop_last=True yields zero batches when a list has fewer samples
        # than one batch; fail clearly instead of dividing by zero below.
        raise ValueError("source and target lists must each contain at least "
                         "one full batch of {} samples".format(train_bs))
    best_acc = 0.0
    best_model = copy.deepcopy(base_network)

    # Per-class memory tensors threaded through loss.SM on every step.
    # NOTE(review): 256 presumably mirrors "feature_dim" above -- confirm.
    Cs_memory = torch.zeros(class_num, 256).cuda()
    Ct_memory = torch.zeros(class_num, 256).cuda()

    # The criterion is stateless, so build it once rather than every step.
    criterion = nn.CrossEntropyLoss()

    try:
        for i in range(config["iterations"]):
            if i % config["test_interval"] == config["test_interval"] - 1:
                # Evaluate on the test loader and keep a copy of the best model.
                base_network.train(False)
                temp_acc = image_classification_test(dset_loaders, base_network)
                if temp_acc > best_acc:
                    best_acc = temp_acc
                    best_model = copy.deepcopy(base_network)
                log_str = "iter: {:05d}, \t precision: {:.4f},\t best_acc:{:.4f}".format(i, temp_acc, best_acc)
                config["out_file"].write(log_str + "\n")
                config["out_file"].flush()
                print(log_str)
            if (i + 1) % config["snapshot_interval"] == 0:
                if not os.path.exists("save/init_model"):
                    os.makedirs("save/init_model")
                torch.save(best_model, 'save/init_model/' + source + '_' + target + '.pkl')

            ## train one iter
            base_network.train(True)
            ad_net.train(True)
            optimizer_classfier = lr_scheduler(optimizer_classfier, i, **schedule_param)
            optimizer_feature = lr_scheduler(optimizer_feature, i, **schedule_param)

            # Restart a loader's iterator whenever it has been fully consumed.
            if i % len_train_source == 0:
                iter_source = iter(dset_loaders["source"])
            if i % len_train_target == 0:
                iter_target = iter(dset_loaders["target"])
            # Builtin next() works on every PyTorch version; the iterator's
            # own .next() method is no longer provided by recent releases.
            inputs_source, labels_source = next(iter_source)
            inputs_target, _ = next(iter_target)  # target labels are not used
            inputs_source, inputs_target, labels_source = inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()
            features_source, outputs_source = base_network(inputs_source)
            features_target, outputs_target = base_network(inputs_target)
            features = torch.cat((features_source, features_target), dim=0)
            transfer_loss = loss.DANN(features, ad_net)
            # Pseudo labels for the target batch: argmax of the current logits.
            pseu_labels_target = torch.argmax(outputs_target, dim=1)

            loss_sm, Cs_memory, Ct_memory = loss.SM(features_source, features_target,
                                                    labels_source, pseu_labels_target,
                                                    Cs_memory, Ct_memory)
            gamma = network.calc_coeff(i)
            classifier_loss = criterion(outputs_source, labels_source)

            loss_total = classifier_loss + gamma * (loss_sm) + transfer_loss

            optimizer_classfier.zero_grad()
            optimizer_feature.zero_grad()

            loss_total.backward()
            optimizer_feature.step()
            optimizer_classfier.step()

            print('step:{: d},\t,class_loss:{:.4f},\t,trans_loss:{:.4f},\t,sm:{:.2f}'
                  ''.format(i, classifier_loss.item(), transfer_loss.item(), loss_sm.item()))
            # Detach the memories so the next step does not backpropagate
            # through this step's graph.
            Cs_memory.detach_()
            Ct_memory.detach_()
    finally:
        # Always release the log handle, even if training is interrupted.
        config["out_file"].close()

True
step: 0,	,class_loss:3.6366,	,trans_loss:0.6927,	,sm:28.31
step: 1,	,class_loss:3.4498,	,trans_loss:0.6928,	,sm:24.96
step: 2,	,class_loss:3.4750,	,trans_loss:0.6931,	,sm:24.05
step: 3,	,class_loss:3.5438,	,trans_loss:0.6930,	,sm:22.64
step: 4,	,class_loss:3.4941,	,trans_loss:0.6935,	,sm:22.27
step: 5,	,class_loss:3.5317,	,trans_loss:0.6926,	,sm:21.70
step: 6,	,class_loss:3.4987,	,trans_loss:0.6932,	,sm:21.59
step: 7,	,class_loss:3.5111,	,trans_loss:0.6927,	,sm:20.13
step: 8,	,class_loss:3.3170,	,trans_loss:0.6926,	,sm:19.32
step: 9,	,class_loss:3.3012,	,trans_loss:0.6935,	,sm:18.54
step: 10,	,class_loss:3.3449,	,trans_loss:0.6933,	,sm:18.30
step: 11,	,class_loss:3.3855,	,trans_loss:0.6936,	,sm:18.40
step: 12,	,class_loss:3.3423,	,trans_loss:0.6921,	,sm:16.18
step: 13,	,class_loss:3.3392,	,trans_loss:0.6930,	,sm:16.05
step: 14,	,class_loss:3.2671,	,trans_loss:0.6931,	,sm:15.75
step: 15,	,class_loss:3.2580,	,trans_loss:0.6919,	,sm:15.67
step: 16,	,class_loss:3.2196,	,trans_loss:0.6

step: 137,	,class_loss:1.1509,	,trans_loss:0.6884,	,sm:3.97
step: 138,	,class_loss:1.3279,	,trans_loss:0.6885,	,sm:3.88
step: 139,	,class_loss:0.8585,	,trans_loss:0.6894,	,sm:3.48
step: 140,	,class_loss:1.3719,	,trans_loss:0.6889,	,sm:3.43
step: 141,	,class_loss:1.0336,	,trans_loss:0.6905,	,sm:3.27
step: 142,	,class_loss:1.0333,	,trans_loss:0.6904,	,sm:3.24
step: 143,	,class_loss:1.2052,	,trans_loss:0.6896,	,sm:3.19
step: 144,	,class_loss:1.0123,	,trans_loss:0.6884,	,sm:3.44
step: 145,	,class_loss:1.0596,	,trans_loss:0.6905,	,sm:3.59
step: 146,	,class_loss:0.7438,	,trans_loss:0.6882,	,sm:3.76
step: 147,	,class_loss:0.9929,	,trans_loss:0.6889,	,sm:3.76
step: 148,	,class_loss:0.9653,	,trans_loss:0.6885,	,sm:3.61
iter: 00149, 	 precision: 0.6847,	 best_acc:0.6847
step: 149,	,class_loss:1.0137,	,trans_loss:0.6894,	,sm:3.76
step: 150,	,class_loss:1.0057,	,trans_loss:0.6879,	,sm:3.69
step: 151,	,class_loss:0.9721,	,trans_loss:0.6888,	,sm:3.63
step: 152,	,class_loss:1.1143,	,trans_loss:0.6894

step: 271,	,class_loss:0.4233,	,trans_loss:0.6863,	,sm:2.50
step: 272,	,class_loss:0.9763,	,trans_loss:0.6855,	,sm:2.64
step: 273,	,class_loss:0.5319,	,trans_loss:0.6849,	,sm:2.81
step: 274,	,class_loss:0.5125,	,trans_loss:0.6870,	,sm:2.80
step: 275,	,class_loss:0.5801,	,trans_loss:0.6846,	,sm:2.64
step: 276,	,class_loss:0.3739,	,trans_loss:0.6863,	,sm:2.50
step: 277,	,class_loss:0.5594,	,trans_loss:0.6839,	,sm:2.41
step: 278,	,class_loss:0.7908,	,trans_loss:0.6839,	,sm:2.29
step: 279,	,class_loss:0.6448,	,trans_loss:0.6859,	,sm:2.42
step: 280,	,class_loss:0.3159,	,trans_loss:0.6837,	,sm:2.47
step: 281,	,class_loss:0.3165,	,trans_loss:0.6837,	,sm:2.44
step: 282,	,class_loss:0.6687,	,trans_loss:0.6861,	,sm:2.23
step: 283,	,class_loss:0.3997,	,trans_loss:0.6857,	,sm:2.33
step: 284,	,class_loss:0.5728,	,trans_loss:0.6863,	,sm:2.17
step: 285,	,class_loss:0.3834,	,trans_loss:0.6839,	,sm:2.21
step: 286,	,class_loss:0.3599,	,trans_loss:0.6825,	,sm:2.13
step: 287,	,class_loss:0.8252,	,trans_lo

step: 405,	,class_loss:0.4479,	,trans_loss:0.6824,	,sm:1.77
step: 406,	,class_loss:0.4068,	,trans_loss:0.6820,	,sm:1.79
step: 407,	,class_loss:0.4933,	,trans_loss:0.6804,	,sm:1.89
step: 408,	,class_loss:0.5439,	,trans_loss:0.6832,	,sm:1.78
step: 409,	,class_loss:0.3978,	,trans_loss:0.6818,	,sm:1.78
step: 410,	,class_loss:0.3266,	,trans_loss:0.6811,	,sm:1.82
step: 411,	,class_loss:0.5112,	,trans_loss:0.6840,	,sm:1.83
step: 412,	,class_loss:0.4274,	,trans_loss:0.6814,	,sm:1.84
step: 413,	,class_loss:0.4183,	,trans_loss:0.6801,	,sm:1.80
step: 414,	,class_loss:0.2824,	,trans_loss:0.6866,	,sm:1.71
step: 415,	,class_loss:0.3589,	,trans_loss:0.6851,	,sm:1.78
step: 416,	,class_loss:0.2482,	,trans_loss:0.6814,	,sm:1.74
step: 417,	,class_loss:0.3989,	,trans_loss:0.6813,	,sm:1.66
step: 418,	,class_loss:0.2875,	,trans_loss:0.6837,	,sm:1.56
step: 419,	,class_loss:0.3805,	,trans_loss:0.6842,	,sm:1.57
step: 420,	,class_loss:0.2496,	,trans_loss:0.6806,	,sm:1.88
step: 421,	,class_loss:0.2719,	,trans_lo

step: 540,	,class_loss:0.2583,	,trans_loss:0.6859,	,sm:1.61
step: 541,	,class_loss:0.5185,	,trans_loss:0.6865,	,sm:1.64
step: 542,	,class_loss:0.2725,	,trans_loss:0.6826,	,sm:1.49
step: 543,	,class_loss:0.2757,	,trans_loss:0.6812,	,sm:1.51
step: 544,	,class_loss:0.3131,	,trans_loss:0.6833,	,sm:1.61
step: 545,	,class_loss:0.3226,	,trans_loss:0.6904,	,sm:1.54
step: 546,	,class_loss:0.2938,	,trans_loss:0.6841,	,sm:1.61
step: 547,	,class_loss:0.0721,	,trans_loss:0.6846,	,sm:1.70
step: 548,	,class_loss:0.0914,	,trans_loss:0.6839,	,sm:1.67
iter: 00549, 	 precision: 0.8032,	 best_acc:0.8032
step: 549,	,class_loss:0.2771,	,trans_loss:0.6845,	,sm:1.81
step: 550,	,class_loss:0.2705,	,trans_loss:0.6854,	,sm:1.59
step: 551,	,class_loss:0.3501,	,trans_loss:0.6852,	,sm:1.65
step: 552,	,class_loss:0.2185,	,trans_loss:0.6868,	,sm:1.77
step: 553,	,class_loss:0.2330,	,trans_loss:0.6802,	,sm:1.67
step: 554,	,class_loss:0.2301,	,trans_loss:0.6853,	,sm:1.51
step: 555,	,class_loss:0.2643,	,trans_loss:0.6877

step: 674,	,class_loss:0.1582,	,trans_loss:0.6909,	,sm:1.25
step: 675,	,class_loss:0.2248,	,trans_loss:0.6865,	,sm:1.26
step: 676,	,class_loss:0.1918,	,trans_loss:0.6878,	,sm:1.32
step: 677,	,class_loss:0.1871,	,trans_loss:0.6876,	,sm:1.28
step: 678,	,class_loss:0.3501,	,trans_loss:0.6894,	,sm:1.31
step: 679,	,class_loss:0.2540,	,trans_loss:0.6880,	,sm:1.37
step: 680,	,class_loss:0.2294,	,trans_loss:0.6894,	,sm:1.33
step: 681,	,class_loss:0.1627,	,trans_loss:0.6902,	,sm:1.29
step: 682,	,class_loss:0.2091,	,trans_loss:0.6915,	,sm:1.22
step: 683,	,class_loss:0.3291,	,trans_loss:0.6912,	,sm:1.31
step: 684,	,class_loss:0.2887,	,trans_loss:0.6914,	,sm:1.40
step: 685,	,class_loss:0.1968,	,trans_loss:0.6856,	,sm:1.46
step: 686,	,class_loss:0.2849,	,trans_loss:0.6901,	,sm:1.32
step: 687,	,class_loss:0.1438,	,trans_loss:0.6895,	,sm:1.32
step: 688,	,class_loss:0.1624,	,trans_loss:0.6901,	,sm:1.27
step: 689,	,class_loss:0.1585,	,trans_loss:0.6878,	,sm:1.26
step: 690,	,class_loss:0.1601,	,trans_lo

step: 808,	,class_loss:0.1963,	,trans_loss:0.6959,	,sm:1.20
step: 809,	,class_loss:0.1691,	,trans_loss:0.6950,	,sm:1.20
step: 810,	,class_loss:0.1713,	,trans_loss:0.6924,	,sm:1.06
step: 811,	,class_loss:0.2000,	,trans_loss:0.6936,	,sm:1.17
step: 812,	,class_loss:0.2784,	,trans_loss:0.6990,	,sm:1.27
step: 813,	,class_loss:0.3460,	,trans_loss:0.6966,	,sm:1.29
step: 814,	,class_loss:0.2385,	,trans_loss:0.6952,	,sm:1.27
step: 815,	,class_loss:0.1999,	,trans_loss:0.6990,	,sm:1.07
step: 816,	,class_loss:0.1438,	,trans_loss:0.6974,	,sm:1.10
step: 817,	,class_loss:0.1493,	,trans_loss:0.6954,	,sm:1.08
step: 818,	,class_loss:0.2428,	,trans_loss:0.6943,	,sm:0.96
step: 819,	,class_loss:0.1305,	,trans_loss:0.6964,	,sm:1.06
step: 820,	,class_loss:0.2475,	,trans_loss:0.6951,	,sm:1.10
step: 821,	,class_loss:0.1259,	,trans_loss:0.6956,	,sm:1.16
step: 822,	,class_loss:0.1956,	,trans_loss:0.6918,	,sm:1.15
step: 823,	,class_loss:0.1984,	,trans_loss:0.6937,	,sm:1.02
step: 824,	,class_loss:0.0997,	,trans_lo

step: 943,	,class_loss:0.1619,	,trans_loss:0.6986,	,sm:0.95
step: 944,	,class_loss:0.0659,	,trans_loss:0.6988,	,sm:0.88
step: 945,	,class_loss:0.0611,	,trans_loss:0.6969,	,sm:0.90
step: 946,	,class_loss:0.2344,	,trans_loss:0.6999,	,sm:0.95
step: 947,	,class_loss:0.1089,	,trans_loss:0.6963,	,sm:0.85
step: 948,	,class_loss:0.1624,	,trans_loss:0.6989,	,sm:1.13
iter: 00949, 	 precision: 0.8273,	 best_acc:0.8394
step: 949,	,class_loss:0.0831,	,trans_loss:0.6958,	,sm:1.09
step: 950,	,class_loss:0.1633,	,trans_loss:0.7002,	,sm:1.01
step: 951,	,class_loss:0.1223,	,trans_loss:0.6999,	,sm:0.98
step: 952,	,class_loss:0.0574,	,trans_loss:0.6983,	,sm:0.85
step: 953,	,class_loss:0.0863,	,trans_loss:0.6971,	,sm:0.83
step: 954,	,class_loss:0.0533,	,trans_loss:0.6960,	,sm:0.84
step: 955,	,class_loss:0.1129,	,trans_loss:0.6965,	,sm:0.86
step: 956,	,class_loss:0.1195,	,trans_loss:0.6954,	,sm:0.96
step: 957,	,class_loss:0.1151,	,trans_loss:0.6964,	,sm:0.93
step: 958,	,class_loss:0.1204,	,trans_loss:0.7005

step: 1076,	,class_loss:0.1539,	,trans_loss:0.6969,	,sm:1.01
step: 1077,	,class_loss:0.1197,	,trans_loss:0.6967,	,sm:0.89
step: 1078,	,class_loss:0.0809,	,trans_loss:0.6963,	,sm:0.85
step: 1079,	,class_loss:0.0948,	,trans_loss:0.6967,	,sm:0.90
step: 1080,	,class_loss:0.1592,	,trans_loss:0.6973,	,sm:0.88
step: 1081,	,class_loss:0.0666,	,trans_loss:0.6930,	,sm:0.85
step: 1082,	,class_loss:0.1863,	,trans_loss:0.6957,	,sm:0.86
step: 1083,	,class_loss:0.2021,	,trans_loss:0.6954,	,sm:0.93
step: 1084,	,class_loss:0.0888,	,trans_loss:0.6954,	,sm:0.98
step: 1085,	,class_loss:0.0878,	,trans_loss:0.6964,	,sm:0.88
step: 1086,	,class_loss:0.1286,	,trans_loss:0.6965,	,sm:0.90
step: 1087,	,class_loss:0.1387,	,trans_loss:0.6951,	,sm:0.83
step: 1088,	,class_loss:0.0335,	,trans_loss:0.6951,	,sm:0.77
step: 1089,	,class_loss:0.1065,	,trans_loss:0.6964,	,sm:0.78
step: 1090,	,class_loss:0.1203,	,trans_loss:0.6974,	,sm:0.84
step: 1091,	,class_loss:0.1073,	,trans_loss:0.6978,	,sm:0.87
step: 1092,	,class_loss:

step: 1208,	,class_loss:0.1337,	,trans_loss:0.6949,	,sm:0.78
step: 1209,	,class_loss:0.0526,	,trans_loss:0.6977,	,sm:0.81
step: 1210,	,class_loss:0.0812,	,trans_loss:0.6947,	,sm:0.92
step: 1211,	,class_loss:0.0549,	,trans_loss:0.6947,	,sm:0.86
step: 1212,	,class_loss:0.0289,	,trans_loss:0.6944,	,sm:0.79
step: 1213,	,class_loss:0.0561,	,trans_loss:0.6943,	,sm:0.70
step: 1214,	,class_loss:0.0639,	,trans_loss:0.6948,	,sm:0.68
step: 1215,	,class_loss:0.1399,	,trans_loss:0.6939,	,sm:0.79
step: 1216,	,class_loss:0.0905,	,trans_loss:0.6919,	,sm:0.85
step: 1217,	,class_loss:0.0545,	,trans_loss:0.6944,	,sm:0.81
step: 1218,	,class_loss:0.0379,	,trans_loss:0.6964,	,sm:0.76
step: 1219,	,class_loss:0.1231,	,trans_loss:0.6955,	,sm:0.82
step: 1220,	,class_loss:0.0759,	,trans_loss:0.6945,	,sm:0.78
step: 1221,	,class_loss:0.0385,	,trans_loss:0.6963,	,sm:0.73
step: 1222,	,class_loss:0.0602,	,trans_loss:0.6951,	,sm:0.72
step: 1223,	,class_loss:0.0430,	,trans_loss:0.6935,	,sm:0.81
step: 1224,	,class_loss: