Commit

DCAN

xiebinhui committed Jun 19, 2020
1 parent f1e5e96 commit 5108bfa

Showing 28 changed files with 607,761 additions and 2 deletions.
6 changes: 4 additions & 2 deletions README.md
@@ -28,18 +28,20 @@ To install the required python packages, run
```pip install -r requirements.txt ```

## Datasets

### Office-Home
Office-Home dataset can be found [here](http://hemanthdv.org/OfficeHome-Dataset/).

### DomainNet

DomainNet dataset can be found [here](http://ai.bu.edu/M3SDA/).

### Office-31
Office-31 dataset can be found [here](https://people.eecs.berkeley.edu/~jhoffman/domainadapt/).


## Pre-trained models
-Pre-trained models can be downloaded [here]() and put in <root_dir>/pretrained_models
+Pre-trained models can be downloaded [here](https://github.com/BIT-DA/DCAN/releases) and put in <root_dir>/pretrained_models


## Running the code
14,604 changes: 14,604 additions & 0 deletions data/list/domainnet/clipart_test.txt
33,525 changes: 33,525 additions & 0 deletions data/list/domainnet/clipart_train.txt
15,582 changes: 15,582 additions & 0 deletions data/list/domainnet/infograph_test.txt
36,023 changes: 36,023 additions & 0 deletions data/list/domainnet/infograph_train.txt
21,850 changes: 21,850 additions & 0 deletions data/list/domainnet/painting_test.txt
50,416 changes: 50,416 additions & 0 deletions data/list/domainnet/painting_train.txt
51,750 changes: 51,750 additions & 0 deletions data/list/domainnet/quickdraw_test.txt
120,750 changes: 120,750 additions & 0 deletions data/list/domainnet/quickdraw_train.txt
52,041 changes: 52,041 additions & 0 deletions data/list/domainnet/real_test.txt
120,906 changes: 120,906 additions & 0 deletions data/list/domainnet/real_train.txt
20,916 changes: 20,916 additions & 0 deletions data/list/domainnet/sketch_test.txt
48,212 changes: 48,212 additions & 0 deletions data/list/domainnet/sketch_train.txt
2,427 changes: 2,427 additions & 0 deletions data/list/home/Art_65.txt
4,365 changes: 4,365 additions & 0 deletions data/list/home/Clipart_65.txt
4,439 changes: 4,439 additions & 0 deletions data/list/home/Product_65.txt
4,357 changes: 4,357 additions & 0 deletions data/list/home/RealWorld_65.txt
2,817 changes: 2,817 additions & 0 deletions data/list/office/amazon_31.txt
498 changes: 498 additions & 0 deletions data/list/office/dslr_31.txt
795 changes: 795 additions & 0 deletions data/list/office/webcam_31.txt

(Large diffs are not rendered by default.)

103 changes: 103 additions & 0 deletions data_list.py
@@ -0,0 +1,103 @@
# -*- coding: utf-8 -*-

import numpy as np
from PIL import Image
from torch.utils.data import Dataset


def make_dataset(image_list, labels):
    """Build a list of (image_path, label) pairs from the lines of an image list file."""
    if labels is not None:
        len_ = len(image_list)
        images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)]
    else:
        if len(image_list[0].split()) > 2:
            # Multi-label case: each line is "path label_1 label_2 ...".
            images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list]
        else:
            # Single-label case: each line is "path label".
            images = [(val.split()[0], int(val.split()[1])) for val in image_list]
    return images


def rgb_loader(path):
with open(path, 'rb') as f:
with Image.open(f) as img:
return img.convert('RGB')


def l_loader(path):
with open(path, 'rb') as f:
with Image.open(f) as img:
return img.convert('L')


class ImageList(Dataset):
def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
imgs = make_dataset(image_list, labels)
        if len(imgs) == 0:
            raise RuntimeError("Found 0 images in the supplied image_list")

self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
        if mode == 'RGB':
            self.loader = rgb_loader
        elif mode == 'L':
            self.loader = l_loader
        else:
            raise ValueError("mode must be 'RGB' or 'L', got {!r}".format(mode))

def __getitem__(self, index):
path, target = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def __len__(self):
return len(self.imgs)


class ImageValueList(Dataset):
def __init__(self, image_list, labels=None, transform=None, target_transform=None,
loader=rgb_loader):
imgs = make_dataset(image_list, labels)
        if len(imgs) == 0:
            raise RuntimeError("Found 0 images in the supplied image_list")

self.imgs = imgs
self.values = [1.0] * len(imgs)
self.transform = transform
self.target_transform = target_transform
self.loader = loader

def set_values(self, values):
self.values = values

def __getitem__(self, index):
path, target = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def __len__(self):
return len(self.imgs)


class DsetThreeChannels(Dataset):
# Make sure that your dataset actually returns images with only one channel!

def __init__(self, dset):
self.dset = dset

def __getitem__(self, index):
image, label = self.dset[index]
return image.repeat(3, 1, 1), label

def __len__(self):
return len(self.dset)
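
For reference, a minimal sketch of driving `ImageList` from one of the list files added in this commit; the list-file path, image size, and batch size here are illustrative assumptions, not values taken from the repository.

```python
# Minimal usage sketch (illustrative parameters).
from torch.utils.data import DataLoader
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Each line of the list file is "path label", matching make_dataset above.
with open('data/list/office/amazon_31.txt') as f:
    lines = f.readlines()

dataset = ImageList(lines, transform=transform, mode='RGB')
loader = DataLoader(dataset, batch_size=36, shuffle=True, num_workers=4)

images, targets = next(iter(loader))  # images: (36, 3, 224, 224), targets: (36,)
```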
40 changes: 40 additions & 0 deletions logger.py
@@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-

import os
import time
import logging
from logging import handlers


class Logger(object):

level_relations = {
'debug': logging.DEBUG,
'info': logging.INFO,
'warning': logging.WARNING,
'error': logging.ERROR,
'crit': logging.CRITICAL
} # Log level

def __init__(self, logroot, filename, level='info', when='D', fmt='%(message)s'):

if not os.path.exists(logroot):
os.makedirs(logroot)

        # Note: ':' in the timestamp makes the filename invalid on Windows; adjust the format there.
        filename = os.path.join(logroot, time.strftime('%Y-%m-%d %H:%M:%S') + ' ' + filename + '.log')
self.logger = logging.getLogger(filename)
format_str = logging.Formatter(fmt) # Set the log format
self.logger.setLevel(self.level_relations.get(level)) # Set the log level
sh = logging.StreamHandler() # Output to the screen
sh.setFormatter(format_str)

# Write a processor to a file that generates the file automatically at specified intervals
th = handlers.TimedRotatingFileHandler(filename=filename, when=when, encoding='utf-8')
th.setFormatter(format_str)
self.logger.addHandler(sh)
self.logger.addHandler(th)


if __name__ == '__main__':
log = Logger(logroot='log/', filename='test', level='debug')
log.logger.debug('Logger test.')
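
One usage note: each `Logger` instance attaches both a stream handler and a timed rotating file handler to the named logger, so a training script would normally create a single instance per run and reuse its `.logger`. A minimal sketch (the run name is an assumption):

```python
# Illustrative sketch: one Logger per training run, reused everywhere.
log = Logger(logroot='log/', filename='dcan_office31_a2w', level='info')

for step in range(3):
    log.logger.info('iter {:05d}  loss {:.4f}'.format(step, 0.1 * step))
```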
59 changes: 59 additions & 0 deletions loss.py
@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-

import torch


def EntropyLoss(input_):
    """Mean entropy of the input probabilities; near-zero entries are masked out to avoid log(0)."""
    mask = input_.ge(0.0000001)
mask_out = torch.masked_select(input_, mask)
entropy = - (torch.sum(mask_out * torch.log(mask_out)))
return entropy / float(input_.size(0))


def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Sum of `kernel_num` Gaussian (RBF) kernels with bandwidths spaced by factors of `kernel_mul`."""
    n_samples = int(source.size()[0]) + int(target.size()[0])
total = torch.cat([source, target], dim=0)
total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
L2_distance = ((total0 - total1) ** 2).sum(2)
if fix_sigma:
bandwidth = fix_sigma
else:
bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
bandwidth /= kernel_mul ** (kernel_num // 2)
bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
return sum(kernel_val) # /len(kernel_val)


def MMD(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Linear-time MMD estimate over consecutive sample pairs; assumes equal source/target batch sizes."""
batch_size = int(source.size()[0])
kernels = guassian_kernel(source, target,
kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
loss = 0
for i in range(batch_size):
s1, s2 = i, (i + 1) % batch_size
t1, t2 = s1 + batch_size, s2 + batch_size
loss += kernels[s1, s2] + kernels[t1, t2]
loss -= kernels[s1, t2] + kernels[s2, t1]
return loss / float(batch_size)


def MMD_reg(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """MMD variant normalized by the combined batch size.

    Note: target rows start at index batch_size_source in `kernels`, so the
    offsets below are only consistent when the two batch sizes are equal.
    """
    batch_size_source = int(source.size()[0])
batch_size_target = int(target.size()[0])
kernels = guassian_kernel(source, target,
kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
loss = 0
for i in range(batch_size_source):
s1, s2 = i, (i + 1) % batch_size_source
t1, t2 = s1 + batch_size_target, s2 + batch_size_target
loss += kernels[s1, s2] + kernels[t1, t2]
loss -= kernels[s1, t2] + kernels[s2, t1]
return loss / float(batch_size_source + batch_size_target)
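
For context, the pairing in `MMD` above corresponds to the linear-time estimator of the squared maximum mean discrepancy, summing k(x_i, x_{i+1}) + k(y_i, y_{i+1}) - k(x_i, y_{i+1}) - k(x_{i+1}, y_i) over the batch with the multi-bandwidth Gaussian kernel. A quick sanity-check sketch (the feature shapes are assumptions):

```python
# Sanity check: identical batches cancel pairwise, shifted batches give a larger value.
import torch

torch.manual_seed(0)
source = torch.randn(32, 256)          # e.g. a batch of bottleneck features
same = MMD(source, source.clone())     # terms cancel pairwise, so this is 0
shifted = MMD(source, torch.randn(32, 256) + 2.0)
print(same.item(), shifted.item())     # shifted >> same
```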
17 changes: 17 additions & 0 deletions lr_schedule.py
@@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-


def inv_lr_scheduler(optimizer, iter_num, gamma, power, lr=0.001, weight_decay=0.0005):
    """Inverse decay: lr = lr_0 * (1 + gamma * iter_num) ** (-power), scaled per parameter group."""
    lr = lr * (1 + gamma * iter_num) ** (-power)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr * param_group['lr_mult']
        param_group['weight_decay'] = weight_decay * param_group['decay_mult']

    return optimizer


schedule_dict = {"inv": inv_lr_scheduler}
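
`inv_lr_scheduler` expects every parameter group to carry `lr_mult` and `decay_mult` entries. A minimal wiring sketch (the modules, multipliers, and hyperparameter values are illustrative assumptions):

```python
# Illustrative sketch: per-group multipliers, as inv_lr_scheduler expects.
import torch

backbone = torch.nn.Linear(256, 256)     # stand-in for a pretrained feature extractor
classifier = torch.nn.Linear(256, 31)    # stand-in for a newly initialized head

optimizer = torch.optim.SGD([
    {'params': backbone.parameters(), 'lr': 1e-3, 'lr_mult': 1, 'decay_mult': 2},
    {'params': classifier.parameters(), 'lr': 1e-2, 'lr_mult': 10, 'decay_mult': 2},
], momentum=0.9)

for iter_num in range(1000):
    optimizer = inv_lr_scheduler(optimizer, iter_num, gamma=0.001, power=0.75, lr=0.001)
    # ... forward pass, loss.backward(), optimizer.step() ...
```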
