In [1]:
import argparse
import os
import os.path as osp

import numpy as np
import torch
print(torch.cuda.is_available())

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import network
import loss
import pre_process as prep
import lr_schedule
from pre_process import ImageList, image_classification_test
import copy
import random



if __name__ == "__main__":
    # Initialization stage of RSDA-MSTN: train the backbone + classifier with a
    # plain cross-entropy loss on labelled source images, periodically evaluate
    # on the target list, and snapshot the best-scoring model.
    parser = argparse.ArgumentParser(description='Code for RSDA-MSTN')
    parser.add_argument('--gpu_id', type=str, nargs='?', default='1', help="device id to run")
    parser.add_argument('--source', type=str, default='amazon',choices=["amazon", "dslr","webcam"])
    parser.add_argument('--target', type=str, default='dslr', choices=["amazon", "dslr", "webcam"])
    parser.add_argument('--test_interval', type=int, default=50, help="interval of two continuous test phase")
    parser.add_argument('--snapshot_interval', type=int, default=1000, help="interval of two continuous output model")
    parser.add_argument('--lr', type=float, default=0.001, help="learning rate")
    parser.add_argument('--stages', type=int, default=6, help="training stages")
    # An empty argv is parsed on purpose, so the defaults above are always used
    # and real command-line flags are ignored.
    args = parser.parse_args([])
    s_dset_path = 'data/office/' + args.source + '_list.txt' #'../../data/office/' + args.source + '_list.txt'
    t_dset_path = 'data/office/' + args.target + '_list.txt' #'../../data/office/' + args.target + '_list.txt'

    # NOTE(review): the visible device is hard-coded to '0'; args.gpu_id is only
    # used further down to decide whether to wrap the networks in DataParallel.
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    config = {}
    config["source"] = args.source
    config["target"] = args.target
    config["gpu"] = args.gpu_id
    config["test_interval"] = args.test_interval
    config["snapshot_interval"] = args.snapshot_interval
    config["output_for_test"] = True
    config["output_path"] = "snapshot/init"
    if not osp.exists(config["output_path"]):
        os.makedirs(config["output_path"])
    # NOTE(review): this log file is never closed inside this block — confirm
    # whether later code still writes through config["out_file"].
    config["out_file"] = open(osp.join(config["output_path"],args.source+"_"+args.target+ "class_log.txt"), "w")

    config["prep"] = {'params':{"resize_size":256, "crop_size":224}}
    config["network"] = {"name":network.ResNet50,
            "params":{"new_cls":True,"feature_dim":256,"class_num":31} }
    config["optimizer"] = {"type":optim.SGD, "optim_params":{'lr':args.lr, "momentum":0.9,
                           "weight_decay":0.0005, "nesterov":True}, "lr_type":"inv",
                           "lr_param":{"lr":args.lr, "gamma":0.001, "power":0.75} }
    config["data"] = {"source":{"list_path":s_dset_path, "batch_size":18},
                      "target":{"list_path":t_dset_path, "batch_size":18},
                      "test":{"list_path":t_dset_path, "batch_size":36}}
    config["out_file"].flush()
    # Iteration budget and RNG seed depend on the (source, target) pair.
    if config["source"] == "amazon" and config["target"] == "dslr":
        config["iterations"] = 2000
        seed = 0
    elif config["source"] == "amazon" and config["target"] == "webcam":
        config["iterations"] = 2000
        seed = 0
    else:
        config["iterations"] = 4000
        seed = 1


    # Seed every RNG in use and request deterministic cuDNN kernels.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

    config["out_file"].write('\n--- initialization ---\n')
    source = config["source"]
    target = config["target"]
    prep_dict = {}
    prep_dict["source"] = prep.image_train(**config["prep"]['params'])
    prep_dict["target"] = prep.image_train(**config["prep"]['params'])

    prep_dict["test"] = prep.image_test(**config["prep"]['params'])

    ## prepare data
    dsets = {}
    dset_loaders = {}
    data_config = config["data"]
    train_bs = data_config["source"]["batch_size"]
    test_bs = data_config["test"]["batch_size"]
    # Read each list file inside a context manager so the handle is closed
    # right away instead of being left to the garbage collector.
    with open(data_config["source"]["list_path"]) as list_file:
        source_lines = list_file.readlines()
    dsets["source"] = ImageList(source_lines, transform=prep_dict["source"])
    dset_loaders["source"] = DataLoader(dsets["source"], batch_size=train_bs,
                                        shuffle=True, num_workers=0, drop_last=True)
    with open(data_config["target"]["list_path"]) as list_file:
        target_lines = list_file.readlines()
    dsets["target"] = ImageList(target_lines, transform=prep_dict["target"])
    dset_loaders["target"] = DataLoader(dsets["target"], batch_size=train_bs,
                                        shuffle=True, num_workers=0, drop_last=True)

    with open(data_config["test"]["list_path"]) as list_file:
        test_lines = list_file.readlines()
    dsets["test"] = ImageList(test_lines, transform=prep_dict["test"])
    dset_loaders["test"] = DataLoader(dsets["test"], batch_size=test_bs,
                                      shuffle=False, num_workers=0)

    class_num = config["network"]["params"]["class_num"]

    ## set base network
    net_config = config["network"]
    base_network = net_config["name"](**net_config["params"])
    base_network = base_network.cuda()

    ## add additional network for some methods
    ad_net = network.AdversarialNetwork(base_network.output_num(), 1024)
    ad_net = ad_net.cuda()

    # Build the optimizer parameter groups BEFORE any DataParallel wrapping:
    # nn.DataParallel does not forward custom methods such as get_parameters(),
    # so calling them on the wrapper would raise AttributeError.
    parameter_classifier = [base_network.get_parameters()[1]]
    parameter_feature = base_network.get_parameters()[0:1] + ad_net.get_parameters()

    gpus = config['gpu'].split(',')
    if len(gpus) > 1:
        ad_net = nn.DataParallel(ad_net)
        base_network = nn.DataParallel(base_network)

    ## set optimizer
    optimizer_config = config["optimizer"]
    optimizer_classfier = optimizer_config["type"](parameter_classifier,
                                                   **(optimizer_config["optim_params"]))
    optimizer_feature = optimizer_config["type"](parameter_feature,
                                                 **(optimizer_config["optim_params"]))
    # Initial learning rate of every parameter group (not used further in this block).
    param_lr = []
    for param_group in optimizer_feature.param_groups:
        param_lr.append(param_group["lr"])
    param_lr.append(optimizer_classfier.param_groups[0]["lr"])
    schedule_param = optimizer_config["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[optimizer_config["lr_type"]]

    ## train
    len_train_source = len(dset_loaders["source"])
    len_train_target = len(dset_loaders["target"])
    best_acc = 0.0
    best_model = copy.deepcopy(base_network)

    # Per-class (class_num x 256) buffers; this block only detaches them.
    Cs_memory=torch.zeros(class_num,256).cuda()
    Ct_memory=torch.zeros(class_num,256).cuda()

    # The loss module is stateless, so build it once instead of every iteration.
    criterion = nn.CrossEntropyLoss()

    for i in range(config["iterations"]):
        # Evaluate on the last iteration of every test_interval window and keep
        # a deep copy of the best-scoring network seen so far.
        if i % config["test_interval"] == config["test_interval"] - 1:
            base_network.train(False)
            temp_acc = image_classification_test(dset_loaders,base_network)
            temp_model = base_network
            if temp_acc > best_acc:
                best_acc = temp_acc
                best_model = copy.deepcopy(temp_model)
            log_str = "iter: {:05d}, \t precision: {:.4f},\t best_acc:{:.4f}".format(i, temp_acc, best_acc)
            config["out_file"].write(log_str + "\n")
            config["out_file"].flush()
            print(log_str)
        # Periodically persist the best model found so far.
        if (i + 1) % config["snapshot_interval"] == 0:
            if not os.path.exists("save/init_model"):
                os.makedirs("save/init_model")
            torch.save(best_model, 'save/init_model/' + source + '_' + target + 'class.pkl')

        ## train one iter
        base_network.train(True)
        ad_net.train(True)
        # The scheduler is called with the optimizer and the iteration index and
        # its return value replaces the optimizer.
        optimizer_classfier = lr_scheduler(optimizer_classfier, i, **schedule_param)
        optimizer_feature = lr_scheduler(optimizer_feature, i, **schedule_param)

        # Restart a loader's iterator whenever a full pass has been consumed.
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders["source"])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders["target"])
        # Use the builtin next(): DataLoader iterators in recent PyTorch
        # releases no longer expose a .next() method.
        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        inputs_source, inputs_target, labels_source = inputs_source.cuda(), inputs_target.cuda(), labels_source.cuda()
        features_source, outputs_source = base_network(inputs_source)

        # Only the supervised source classification loss is optimized here; the
        # target batch is drawn but does not enter the loss.
        classifier_loss = criterion(outputs_source, labels_source)

        loss_total = classifier_loss

        optimizer_classfier.zero_grad()
        optimizer_feature.zero_grad()

        loss_total.backward()
        optimizer_feature.step()
        optimizer_classfier.step()

        print('step:{: d},\t,class_loss:{:.4f},\t'
              ''.format(i, classifier_loss.item()))
        Cs_memory.detach_()
        Ct_memory.detach_()

True
step: 0,	,class_loss:3.6211,	
step: 1,	,class_loss:3.5711,	
step: 2,	,class_loss:3.4839,	
step: 3,	,class_loss:3.3255,	
step: 4,	,class_loss:3.6087,	
step: 5,	,class_loss:3.3213,	
step: 6,	,class_loss:3.5033,	
step: 7,	,class_loss:3.5279,	
step: 8,	,class_loss:3.3490,	
step: 9,	,class_loss:3.5655,	
step: 10,	,class_loss:3.4399,	
step: 11,	,class_loss:3.4043,	
step: 12,	,class_loss:3.4197,	
step: 13,	,class_loss:3.3221,	
step: 14,	,class_loss:3.4284,	
step: 15,	,class_loss:3.2974,	
step: 16,	,class_loss:3.2481,	
step: 17,	,class_loss:3.2491,	
step: 18,	,class_loss:3.0499,	
step: 19,	,class_loss:3.2074,	
step: 20,	,class_loss:3.1360,	
step: 21,	,class_loss:3.1376,	
step: 22,	,class_loss:2.9847,	
step: 23,	,class_loss:3.3060,	
step: 24,	,class_loss:3.0913,	
step: 25,	,class_loss:3.1680,	
step: 26,	,class_loss:3.0748,	
step: 27,	,class_loss:3.1650,	
step: 28,	,class_loss:2.9017,	
step: 29,	,class_loss:3.0316,	
step: 30,	,class_loss:3.0308,	
step: 31,	,class_loss:2.9781,	
step: 32,	,cl

step: 252,	,class_loss:0.4355,	
step: 253,	,class_loss:0.9386,	
step: 254,	,class_loss:0.7386,	
step: 255,	,class_loss:1.3248,	
step: 256,	,class_loss:1.4425,	
step: 257,	,class_loss:0.7961,	
step: 258,	,class_loss:0.6981,	
step: 259,	,class_loss:0.8180,	
step: 260,	,class_loss:0.5777,	
step: 261,	,class_loss:0.9423,	
step: 262,	,class_loss:0.8110,	
step: 263,	,class_loss:0.6615,	
step: 264,	,class_loss:0.7535,	
step: 265,	,class_loss:0.6083,	
step: 266,	,class_loss:0.4698,	
step: 267,	,class_loss:0.5188,	
step: 268,	,class_loss:0.7607,	
step: 269,	,class_loss:0.9340,	
step: 270,	,class_loss:0.7027,	
step: 271,	,class_loss:0.7894,	
step: 272,	,class_loss:0.4854,	
step: 273,	,class_loss:0.8080,	
step: 274,	,class_loss:0.6794,	
step: 275,	,class_loss:1.0757,	
step: 276,	,class_loss:0.6576,	
step: 277,	,class_loss:0.7945,	
step: 278,	,class_loss:1.0014,	
step: 279,	,class_loss:0.9217,	
step: 280,	,class_loss:0.3704,	
step: 281,	,class_loss:0.3575,	
step: 282,	,class_loss:0.4505,	
step: 28

step: 501,	,class_loss:0.3628,	
step: 502,	,class_loss:0.4334,	
step: 503,	,class_loss:0.6476,	
step: 504,	,class_loss:0.1916,	
step: 505,	,class_loss:0.3245,	
step: 506,	,class_loss:0.3387,	
step: 507,	,class_loss:0.6786,	
step: 508,	,class_loss:0.2499,	
step: 509,	,class_loss:0.3741,	
step: 510,	,class_loss:0.5270,	
step: 511,	,class_loss:0.2662,	
step: 512,	,class_loss:0.3192,	
step: 513,	,class_loss:0.2665,	
step: 514,	,class_loss:0.2344,	
step: 515,	,class_loss:0.4299,	
step: 516,	,class_loss:0.3148,	
step: 517,	,class_loss:0.5271,	
step: 518,	,class_loss:0.1887,	
step: 519,	,class_loss:0.1865,	
step: 520,	,class_loss:0.3826,	
step: 521,	,class_loss:0.4122,	
step: 522,	,class_loss:0.4862,	
step: 523,	,class_loss:0.4347,	
step: 524,	,class_loss:0.5995,	
step: 525,	,class_loss:0.6023,	
step: 526,	,class_loss:0.2890,	
step: 527,	,class_loss:0.2143,	
step: 528,	,class_loss:0.1981,	
step: 529,	,class_loss:0.5415,	
step: 530,	,class_loss:0.2809,	
step: 531,	,class_loss:0.2010,	
step: 53

step: 750,	,class_loss:0.2242,	
step: 751,	,class_loss:0.4814,	
step: 752,	,class_loss:0.4077,	
step: 753,	,class_loss:0.2368,	
step: 754,	,class_loss:0.2348,	
step: 755,	,class_loss:0.1341,	
step: 756,	,class_loss:0.2043,	
step: 757,	,class_loss:0.7372,	
step: 758,	,class_loss:0.2811,	
step: 759,	,class_loss:0.3056,	
step: 760,	,class_loss:0.5513,	
step: 761,	,class_loss:0.1652,	
step: 762,	,class_loss:0.2404,	
step: 763,	,class_loss:0.5008,	
step: 764,	,class_loss:0.1065,	
step: 765,	,class_loss:0.1975,	
step: 766,	,class_loss:0.2561,	
step: 767,	,class_loss:0.3225,	
step: 768,	,class_loss:0.6613,	
step: 769,	,class_loss:0.1474,	
step: 770,	,class_loss:0.1534,	
step: 771,	,class_loss:0.2637,	
step: 772,	,class_loss:0.3502,	
step: 773,	,class_loss:0.4227,	
step: 774,	,class_loss:0.1324,	
step: 775,	,class_loss:0.4348,	
step: 776,	,class_loss:0.2685,	
step: 777,	,class_loss:0.4049,	
step: 778,	,class_loss:0.4275,	
step: 779,	,class_loss:0.1435,	
step: 780,	,class_loss:0.1067,	
step: 78

step: 999,	,class_loss:0.1114,	
step: 1000,	,class_loss:0.2702,	
step: 1001,	,class_loss:0.2015,	
step: 1002,	,class_loss:0.2125,	
step: 1003,	,class_loss:0.2203,	
step: 1004,	,class_loss:0.0639,	
step: 1005,	,class_loss:0.0713,	
step: 1006,	,class_loss:0.0511,	
step: 1007,	,class_loss:0.2296,	
step: 1008,	,class_loss:0.0798,	
step: 1009,	,class_loss:0.1187,	
step: 1010,	,class_loss:0.0250,	
step: 1011,	,class_loss:0.2705,	
step: 1012,	,class_loss:0.0991,	
step: 1013,	,class_loss:0.1491,	
step: 1014,	,class_loss:0.0572,	
step: 1015,	,class_loss:0.1819,	
step: 1016,	,class_loss:0.1397,	
step: 1017,	,class_loss:0.2009,	
step: 1018,	,class_loss:0.0953,	
step: 1019,	,class_loss:0.1654,	
step: 1020,	,class_loss:0.1834,	
step: 1021,	,class_loss:0.1299,	
step: 1022,	,class_loss:0.0870,	
step: 1023,	,class_loss:0.2833,	
step: 1024,	,class_loss:0.4132,	
step: 1025,	,class_loss:0.1197,	
step: 1026,	,class_loss:0.0625,	
step: 1027,	,class_loss:0.0439,	
step: 1028,	,class_loss:0.3241,	
step: 1029,

step: 1242,	,class_loss:0.1683,	
step: 1243,	,class_loss:0.1423,	
step: 1244,	,class_loss:0.1099,	
step: 1245,	,class_loss:0.3299,	
step: 1246,	,class_loss:0.1544,	
step: 1247,	,class_loss:0.0367,	
step: 1248,	,class_loss:0.0756,	
iter: 01249, 	 precision: 0.7610,	 best_acc:0.7871
step: 1249,	,class_loss:0.0950,	
step: 1250,	,class_loss:0.0810,	
step: 1251,	,class_loss:0.1501,	
step: 1252,	,class_loss:0.0305,	
step: 1253,	,class_loss:0.0737,	
step: 1254,	,class_loss:0.0389,	
step: 1255,	,class_loss:0.0724,	
step: 1256,	,class_loss:0.3096,	
step: 1257,	,class_loss:0.1260,	
step: 1258,	,class_loss:0.1345,	
step: 1259,	,class_loss:0.0883,	
step: 1260,	,class_loss:0.0262,	
step: 1261,	,class_loss:0.0464,	
step: 1262,	,class_loss:0.1968,	
step: 1263,	,class_loss:0.1003,	
step: 1264,	,class_loss:0.0661,	
step: 1265,	,class_loss:0.1603,	
step: 1266,	,class_loss:0.1573,	
step: 1267,	,class_loss:0.1114,	
step: 1268,	,class_loss:0.2406,	
step: 1269,	,class_loss:0.1104,	
step: 1270,	,class_loss:0

step: 1483,	,class_loss:0.0657,	
step: 1484,	,class_loss:0.1427,	
step: 1485,	,class_loss:0.1874,	
step: 1486,	,class_loss:0.1691,	
step: 1487,	,class_loss:0.0541,	
step: 1488,	,class_loss:0.0706,	
step: 1489,	,class_loss:0.1566,	
step: 1490,	,class_loss:0.0615,	
step: 1491,	,class_loss:0.0557,	
step: 1492,	,class_loss:0.1868,	
step: 1493,	,class_loss:0.0905,	
step: 1494,	,class_loss:0.0937,	
step: 1495,	,class_loss:0.0635,	
step: 1496,	,class_loss:0.2347,	
step: 1497,	,class_loss:0.0327,	
step: 1498,	,class_loss:0.1114,	
iter: 01499, 	 precision: 0.7671,	 best_acc:0.7871
step: 1499,	,class_loss:0.1085,	
step: 1500,	,class_loss:0.0938,	
step: 1501,	,class_loss:0.0709,	
step: 1502,	,class_loss:0.0685,	
step: 1503,	,class_loss:0.0517,	
step: 1504,	,class_loss:0.0450,	
step: 1505,	,class_loss:0.1911,	
step: 1506,	,class_loss:0.0752,	
step: 1507,	,class_loss:0.1525,	
step: 1508,	,class_loss:0.1958,	
step: 1509,	,class_loss:0.1182,	
step: 1510,	,class_loss:0.0267,	
step: 1511,	,class_loss:0

step: 1724,	,class_loss:0.1179,	
step: 1725,	,class_loss:0.0518,	
step: 1726,	,class_loss:0.0584,	
step: 1727,	,class_loss:0.0207,	
step: 1728,	,class_loss:0.0393,	
step: 1729,	,class_loss:0.0628,	
step: 1730,	,class_loss:0.0435,	
step: 1731,	,class_loss:0.0324,	
step: 1732,	,class_loss:0.0334,	
step: 1733,	,class_loss:0.0839,	
step: 1734,	,class_loss:0.0789,	
step: 1735,	,class_loss:0.0350,	
step: 1736,	,class_loss:0.0435,	
step: 1737,	,class_loss:0.0286,	
step: 1738,	,class_loss:0.0099,	
step: 1739,	,class_loss:0.1105,	
step: 1740,	,class_loss:0.1710,	
step: 1741,	,class_loss:0.0611,	
step: 1742,	,class_loss:0.1094,	
step: 1743,	,class_loss:0.0491,	
step: 1744,	,class_loss:0.1142,	
step: 1745,	,class_loss:0.1175,	
step: 1746,	,class_loss:0.3527,	
step: 1747,	,class_loss:0.0728,	
step: 1748,	,class_loss:0.2496,	
iter: 01749, 	 precision: 0.7831,	 best_acc:0.7871
step: 1749,	,class_loss:0.1530,	
step: 1750,	,class_loss:0.0524,	
step: 1751,	,class_loss:0.1089,	
step: 1752,	,class_loss:0

step: 1965,	,class_loss:0.0494,	
step: 1966,	,class_loss:0.0725,	
step: 1967,	,class_loss:0.0240,	
step: 1968,	,class_loss:0.0105,	
step: 1969,	,class_loss:0.0234,	
step: 1970,	,class_loss:0.0235,	
step: 1971,	,class_loss:0.0490,	
step: 1972,	,class_loss:0.1053,	
step: 1973,	,class_loss:0.1139,	
step: 1974,	,class_loss:0.0138,	
step: 1975,	,class_loss:0.1671,	
step: 1976,	,class_loss:0.1253,	
step: 1977,	,class_loss:0.0225,	
step: 1978,	,class_loss:0.0576,	
step: 1979,	,class_loss:0.0452,	
step: 1980,	,class_loss:0.0389,	
step: 1981,	,class_loss:0.0570,	
step: 1982,	,class_loss:0.1941,	
step: 1983,	,class_loss:0.0798,	
step: 1984,	,class_loss:0.1091,	
step: 1985,	,class_loss:0.0201,	
step: 1986,	,class_loss:0.0490,	
step: 1987,	,class_loss:0.0447,	
step: 1988,	,class_loss:0.1476,	
step: 1989,	,class_loss:0.0374,	
step: 1990,	,class_loss:0.1416,	
step: 1991,	,class_loss:0.0480,	
step: 1992,	,class_loss:0.0353,	
step: 1993,	,class_loss:0.0521,	
step: 1994,	,class_loss:0.0361,	
step: 1995