New file (103 lines added): image-list dataset utilities.
# -*- coding: utf-8 -*-

import numpy as np
from PIL import Image
from torch.utils.data import Dataset


def make_dataset(image_list, labels):
    """Build a list of (image_path, target) tuples from text lines and an optional label array."""
    if labels is not None:
        # Labels supplied separately as an array: pair each stripped path with its label row.
        len_ = len(image_list)
        images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)]
    else:
        if len(image_list[0].split()) > 2:
            # Multi-label lines: "path l1 l2 ... lk".
            images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list]
        else:
            # Single-label lines: "path label".
            images = [(val.split()[0], int(val.split()[1])) for val in image_list]
    return images


def rgb_loader(path):
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')


def l_loader(path):
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('L')


class ImageList(Dataset):
    """Dataset built from a list of "path label" lines rather than a folder hierarchy."""

    def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
        imgs = make_dataset(image_list, labels)
        if len(imgs) == 0:
            raise RuntimeError("Found 0 images in the given image list.")

        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        if mode == 'RGB':
            self.loader = rgb_loader
        elif mode == 'L':
            self.loader = l_loader
        else:
            raise ValueError("mode must be 'RGB' or 'L', got '{}'".format(mode))

    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.imgs)


class ImageValueList(Dataset):
    """Like ImageList, but also keeps one value per sample (e.g. an instance weight), settable via set_values."""

    def __init__(self, image_list, labels=None, transform=None, target_transform=None,
                 loader=rgb_loader):
        imgs = make_dataset(image_list, labels)
        if len(imgs) == 0:
            raise RuntimeError("Found 0 images in the given image list.")

        self.imgs = imgs
        self.values = [1.0] * len(imgs)
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def set_values(self, values):
        self.values = values

    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.imgs)


class DsetThreeChannels(Dataset):
    """Repeat a single-channel image to three channels.

    Make sure that the wrapped dataset actually returns images with only one channel!
    """

    def __init__(self, dset):
        self.dset = dset

    def __getitem__(self, index):
        image, label = self.dset[index]
        return image.repeat(3, 1, 1), label

    def __len__(self):
        return len(self.dset)
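
A minimal usage sketch of ImageList with a DataLoader. The list-file name 'train_list.txt', the transforms, and the batch size are illustrative assumptions, not part of this commit; lines of the file are assumed to look like "path/to/img.jpg 3".

# Hypothetical usage sketch: the file name, transforms and batch size are assumptions.
from torchvision import transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Each line of the list file is expected to be "path label" (or "path l1 l2 ..." for multi-label).
with open('train_list.txt') as f:
    image_lines = f.readlines()

dataset = ImageList(image_lines, transform=transform, mode='RGB')
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

for images, targets in loader:
    pass  # images: (B, 3, 224, 224) float tensor, targets: (B,) integer labels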
New file (40 lines added): logging helper.
# -*- coding: utf-8 -*-

import os
import time
import logging
from logging import handlers


class Logger(object):

    # Mapping from level name to the corresponding logging level.
    level_relations = {
        'debug': logging.DEBUG,
        'info': logging.INFO,
        'warning': logging.WARNING,
        'error': logging.ERROR,
        'crit': logging.CRITICAL
    }

    def __init__(self, logroot, filename, level='info', when='D', fmt='%(message)s'):

        if not os.path.exists(logroot):
            os.makedirs(logroot)

        # Prefix the log file name with a timestamp so every run writes to its own file.
        filename = os.path.join(logroot, time.strftime('%Y-%m-%d %H:%M:%S') + ' ' + filename + '.log')
        self.logger = logging.getLogger(filename)
        format_str = logging.Formatter(fmt)  # Set the log format
        self.logger.setLevel(self.level_relations.get(level))  # Set the log level
        sh = logging.StreamHandler()  # Output to the screen
        sh.setFormatter(format_str)

        # Handler that writes to the file and rotates it automatically at the given interval.
        th = handlers.TimedRotatingFileHandler(filename=filename, when=when, encoding='utf-8')
        th.setFormatter(format_str)
        self.logger.addHandler(sh)
        self.logger.addHandler(th)


if __name__ == '__main__':
    log = Logger(logroot='log/', filename='test', level='debug')
    log.logger.debug('Logger test.')
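
A minimal sketch of how this Logger might record training metrics; the file name, step count, and loss values below are illustrative assumptions.

# Hypothetical usage sketch: names and values are assumptions, not part of this commit.
log = Logger(logroot='log/', filename='train', level='info', when='D')
for step in range(3):
    loss_value = 0.1 * (3 - step)  # placeholder standing in for a real training loss
    log.logger.info('step {:d}  loss {:.4f}'.format(step, loss_value))
# Messages go both to the screen and to a timestamped file under log/,
# which TimedRotatingFileHandler rotates daily because when='D'.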
New file (59 lines added): entropy and MMD loss functions.
# -*- coding: utf-8 -*-

import torch


def EntropyLoss(input_):
    """Mean per-sample entropy of a batch of probability vectors (entries below 1e-7 are masked out)."""
    mask = input_.ge(0.0000001)
    mask_out = torch.masked_select(input_, mask)
    entropy = -(torch.sum(mask_out * torch.log(mask_out)))
    return entropy / float(input_.size(0))


def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Multi-bandwidth Gaussian (RBF) kernel matrix over the concatenated source and target batch."""
    n_samples = int(source.size()[0]) + int(target.size()[0])
    total = torch.cat([source, target], dim=0)
    total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    L2_distance = ((total0 - total1) ** 2).sum(2)
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        # Bandwidth heuristic: mean squared distance over all distinct pairs.
        bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
    kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
    return sum(kernel_val)  # sum of the kernel_num kernels (not averaged)


def MMD(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Linear-time MMD estimate; assumes source and target have the same batch size."""
    batch_size = int(source.size()[0])
    kernels = guassian_kernel(source, target,
                              kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
    loss = 0
    for i in range(batch_size):
        s1, s2 = i, (i + 1) % batch_size
        t1, t2 = s1 + batch_size, s2 + batch_size
        loss += kernels[s1, s2] + kernels[t1, t2]
        loss -= kernels[s1, t2] + kernels[s2, t1]
    return loss / float(batch_size)


def MMD_reg(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """MMD variant for unequal batch sizes.

    Target indices are offset by the target batch size, so this assumes the target
    batch is at least as large as the source batch.
    """
    batch_size_source = int(source.size()[0])
    batch_size_target = int(target.size()[0])
    kernels = guassian_kernel(source, target,
                              kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
    loss = 0
    for i in range(batch_size_source):
        s1, s2 = i, (i + 1) % batch_size_source
        t1, t2 = s1 + batch_size_target, s2 + batch_size_target
        loss += kernels[s1, s2] + kernels[t1, t2]
        loss -= kernels[s1, t2] + kernels[s2, t1]
    return loss / float(batch_size_source + batch_size_target)
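
A minimal sketch of how these losses might be combined in a training step, assuming equal-size source and target batches for MMD; the feature dimensions, the softmax step, and the 0.1 weight are illustrative assumptions.

# Hypothetical usage sketch: shapes and the loss weight are assumptions, not part of this commit.
import torch
import torch.nn.functional as F

source_features = torch.randn(32, 256)   # e.g. backbone features for a source batch
target_features = torch.randn(32, 256)   # target batch of the same size, as MMD expects
target_logits = torch.randn(32, 10)

# EntropyLoss expects probabilities, hence the softmax over class logits.
ent = EntropyLoss(F.softmax(target_logits, dim=1))

# MMD between source and target features; MMD_reg would allow a larger target batch.
mmd = MMD(source_features, target_features)

total_loss = mmd + 0.1 * ent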
New file (17 lines added): learning-rate scheduler.
# -*- coding: utf-8 -*-


def inv_lr_scheduler(optimizer, iter_num, gamma, power, lr=0.001, weight_decay=0.0005):
    """Inverse decay: lr = lr0 * (1 + gamma * iter_num) ** (-power), scaled per parameter group."""
    lr = lr * (1 + gamma * iter_num) ** (-power)
    for param_group in optimizer.param_groups:
        # Each parameter group carries its own 'lr_mult' and 'decay_mult' factors.
        param_group['lr'] = lr * param_group['lr_mult']
        param_group['weight_decay'] = weight_decay * param_group['decay_mult']

    return optimizer


schedule_dict = {"inv": inv_lr_scheduler}
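
A minimal sketch of driving inv_lr_scheduler from a training loop. Every parameter group must carry 'lr_mult' and 'decay_mult' keys, since the scheduler reads them without defaults; the model, multipliers, and hyperparameter values below are illustrative assumptions.

# Hypothetical usage sketch: the model and hyperparameter values are assumptions.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, 10))

# Each group needs 'lr_mult' and 'decay_mult'; here the second layer gets a 10x learning rate.
param_groups = [
    {'params': model[0].parameters(), 'lr_mult': 1.0, 'decay_mult': 2.0},
    {'params': model[2].parameters(), 'lr_mult': 10.0, 'decay_mult': 2.0},
]
optimizer = torch.optim.SGD(param_groups, lr=0.001, momentum=0.9)

for iter_num in range(1000):
    optimizer = schedule_dict['inv'](optimizer, iter_num, gamma=0.001, power=0.75, lr=0.001)
    # ... forward pass, loss.backward() and optimizer.step() would go here ...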