Merge pull request #41 from Media-Smart/sampler
fix: sampler: balance sampler, train_runner: load checkpoint
hxcai committed Oct 20, 2020
2 parents d9a4c59 + 67e8f23 commit 68d4f2e
Showing 6 changed files with 71 additions and 47 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -121,6 +121,7 @@ workdir/
*.pth

# local file
tmp.py
*.txt
train.sh
ttt.py
76 changes: 58 additions & 18 deletions vedastr/dataloaders/samplers/balance_sampler.py
@@ -1,6 +1,8 @@
import copy
import logging
import random

import numpy as np
from torch.utils.data import Sampler

from .registry import SAMPLER
@@ -21,19 +23,52 @@ class BalanceSampler(Sampler):
downsample (bool): Set to True to downsample the larger dataset.
.. warning:: If both oversample and downsample are True, BalanceSampler
oversamples first, so downsample has no effect.
The last batch may have a different batch ratio, which means the last
batch may not be balanced.
"""

def __init__(self, dataset, batch_size, shuffle, oversample=False, downsample=False):
def __init__(self, dataset, batch_size, shuffle, oversample=False, downsample=False, eps=0.1):
assert hasattr(dataset, 'data_range')
assert hasattr(dataset, 'batch_ratio')
self.dataset = dataset
self.samples_range = dataset.data_range
self.batch_ratio = dataset.batch_ratio
self.batch_ratio = np.array(dataset.batch_ratio)
self.batch_size = batch_size
self.batch_sizes = self._compute_each_batch_size()
new_br = self.batch_sizes / self.batch_size
br_diffs = np.abs((new_br - self.batch_ratio))
assert not np.sum(br_diffs > eps), \
    "After computing per-dataset batch sizes from the given batch ratio, the " \
    "difference between the resulting batch ratio (computed from those batch " \
    f"sizes) and the given batch ratio is larger than eps ({eps}).\n" \
    "Please consider increasing eps or the batch size. " \
    f"The computed batch sizes are {self.batch_sizes}, the resulting batch " \
    f"ratios are {new_br}, while the given batch ratios are {self.batch_ratio}.\n" \
    "The max difference between the given and resulting batch ratios " \
    f"is {np.max(br_diffs)}."

assert 0 not in self.batch_sizes, \
    "A batch size of 0 is not supported; per-dataset batch sizes are " \
    "computed from the batch ratio. " \
    f"The computed batch sizes are {self.batch_sizes}."

assert np.sum(self.batch_sizes) == self.batch_size
logging.info(f"The batch ratios actually used are {new_br}")
self.batch_ratio = new_br
self.oversample = oversample
self.downsample = downsample
self.shuffle = shuffle
self._generate_indices_()
self._generate_indices()

def _compute_each_batch_size(self):
batch_sizes = self.batch_ratio * self.batch_size
int_bs = batch_sizes.astype(int)
float_bs = (batch_sizes - int_bs) >= 0.5
diff = self.batch_size - np.sum(int_bs) - np.sum(float_bs)
float_bs[np.where(float_bs == (diff < 0))[0][:int(abs(diff))]] = (diff >= 0)

return (int_bs + float_bs).astype(int)

@property
def _num_samples(self):
@@ -43,7 +78,7 @@ def _num_samples(self):
def _num_samples(self, v):
self.num_samples = v

def _generate_indices_(self):
def _generate_indices(self):
self._num_samples = len(self.dataset)
indices_ = []
# TODO, elegant
@@ -58,31 +93,36 @@ def _generate_indices_(self):
if self.shuffle:
random.shuffle(temp)
indices_.append(temp)
per_dataset_len = [len(index) for index in indices_]
pratios = [l / s for (l, s) in zip(per_dataset_len, self.batch_sizes)]
if self.oversample:
indices_ = self._oversample(indices_)
need_len = [int(np.ceil(max(pratios) * size)) for size in self.batch_sizes]
indices_ = self._oversample(indices_, need_len)
if self.downsample:
indices_ = self._downsample(indices_)
need_len = [int(np.ceil(min(pratios) * size)) for size in self.batch_sizes]
indices_ = self._downsample(indices_, need_len)
return indices_

def __iter__(self):
indices_ = self._generate_indices_()
indices_ = self._generate_indices()
total_nums = len(self) // self.batch_size
sizes = [int(self.batch_size * br) for br in self.batch_ratio]
final_index = [total_nums * size for size in sizes]
final_index = [total_nums * size for size in self.batch_sizes]
indices = []
for idx2 in range(total_nums):
for idx3, size in enumerate(sizes):
for idx3, size in enumerate(self.batch_sizes):
indices += indices_[idx3][idx2 * size:(idx2 + 1) * size]
# TODO,
# oversample or drop last. In the current situation,
# performance may drop a lot because the last batch may not be balanced
for idx4, index in enumerate(final_index):
indices += indices_[idx4][index:]
return iter(indices)

def _oversample(self, indices):
max_len = max([len(index) for index in indices])
def _oversample(self, indices, need_len):
result_indices = []
for idx, index in enumerate(indices):
current_nums = len(index)
need_num = max_len - current_nums
need_num = need_len[idx] - current_nums
total_nums = need_num // current_nums
mod_nums = need_num % current_nums
init_index = copy.copy(index)
@@ -93,16 +133,16 @@ def _oversample(self, indices):
index += new_index
index += random.sample(index, mod_nums)
result_indices.append(index)
self._num_samples = max_len * len(indices)
self._num_samples = np.sum(need_len)

return result_indices

def _downsample(self, indices):
min_len = min([len(index) for index in indices])
def _downsample(self, indices, need_len):
result_indices = []
for idx, index in enumerate(indices):
index = random.sample(index, min_len)
index = random.sample(index, need_len[idx])
result_indices.append(index)
self._num_samples = min_len * len(indices)
self._num_samples = np.sum(need_len)
return result_indices

def __len__(self):
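Note: the new _compute_each_batch_size turns fractional batch ratios into integer per-dataset batch sizes that always sum to batch_size. A minimal standalone sketch of that rounding scheme (the function name and example values here are illustrative, not part of the repository):

    import numpy as np

    def compute_batch_sizes(batch_ratio, batch_size):
        # scale each ratio by the batch size, then split into a floor part
        # and a half-up rounding mask
        sizes = np.array(batch_ratio) * batch_size
        int_bs = sizes.astype(int)
        round_up = (sizes - int_bs) >= 0.5
        # flip just enough borderline entries so the result sums to batch_size
        diff = batch_size - np.sum(int_bs) - np.sum(round_up)
        round_up[np.where(round_up == (diff < 0))[0][:int(abs(diff))]] = (diff >= 0)
        return (int_bs + round_up).astype(int)

    print(compute_batch_sizes([0.5, 0.3, 0.2], 17))  # -> [9 5 3], sums to 17

The eps assertion above then guards against ratios that cannot be honoured at the chosen batch size; e.g. a ratio of 0.05 with batch_size 8 rounds to a zero-sized slice, which the `0 not in self.batch_sizes` assertion rejects.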
19 changes: 4 additions & 15 deletions vedastr/models/bodies/rectificators/spin.py
@@ -1,9 +1,10 @@
# [SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition](https://arxiv.org/abs/2005.13117)
# Not fully implemented yet.
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .registry import RECTIFICATORS
from vedastr.models.bodies.feature_extractors import build_feature_extractor
@@ -47,11 +48,8 @@ def __init__(self, spin, k):
super(SPIN, self).__init__()
self.body = build_feature_extractor(spin['feature_extractor'])
self.spn = SPN(spin['spn'])
self.ain = AIN(spin['ain'])
self.betas = generate_beta(k)
init_weights(self.modules())
# self.spn.head[-1].fc.weight.data.fill_(0)
# self.spn.head[-1].fc.bias.data = torch.from_numpy(np.array([0, 0, 0, 0, 0,0,1,0,0, 0, 0, 0, 0, 0])).float()

def forward(self, x):
b, c, h, w = x.size()
@@ -60,23 +58,14 @@ def forward(self, x):
x = self.body(x)

spn_out = self.spn(x) # 2k+2
ain_out = self.ain(x) # activated by sigmoid

omega = spn_out[:, :-1]
g_out = init_img.requires_grad_(True)

alpha = F.sigmoid(spn_out[:, -1])

# offset
ain_out = F.interpolate(ain_out, size=(h, w), mode='bilinear') # noqa: F811
g_out = alpha[:, None, None, None] * ain_out + (1 - alpha[:, None, None, None]) * init_img
# g_out = alpha * ain_out + (1 - alpha) * init_img
# g_out = init_img
# beta dist on g_out
gamma_out = [g_out ** beta for beta in self.betas]
gamma_out = torch.stack(gamma_out, axis=1).requires_grad_(True)

fusion_img = omega[:, :, None, None, None] * gamma_out
fusion_img = torch.sigmoid(fusion_img.sum(dim=1))

return fusion_img


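For context on the spin.py change: with the AIN branch removed, the rectificator keeps only the structure-preserving intensity fusion, a weighted sum of beta-power remaps of the input squashed by a sigmoid. A minimal sketch under assumed shapes and illustrative exponents (not the repository's exact API):

    import torch

    def spin_fusion(img, omega, betas):
        # img: (B, C, H, W); omega: (B, len(betas)) mixing weights from the
        # SPN head; betas: exponents around 1.0, as produced by generate_beta(k).
        # Each img ** beta is a global intensity remap of the image.
        gamma = torch.stack([img ** beta for beta in betas], dim=1)  # (B, K, C, H, W)
        fused = (omega[:, :, None, None, None] * gamma).sum(dim=1)   # (B, C, H, W)
        return torch.sigmoid(fused)

    img = torch.rand(2, 1, 32, 100)
    omega = torch.softmax(torch.randn(2, 7), dim=1)  # 2k+1 weights for k=3
    out = spin_fusion(img, omega, [0.5, 0.7, 0.9, 1.0, 1.2, 1.5, 2.0])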
13 changes: 3 additions & 10 deletions vedastr/runners/inference_runner.py
@@ -53,7 +53,7 @@ def load_checkpoint(self, filename, map_location='default', strict=True):

return load_checkpoint(self.model, filename, map_location, strict)

def postprocess(self, preds, cfg=None, label=None):
def postprocess(self, preds, cfg=None):
if cfg is not None:
sensitive = cfg.get('sensitive', True)
character = cfg.get('character', '')
@@ -65,27 +65,20 @@ def postprocess(self, preds, cfg=None, label=None):
max_probs, indexes = probs.max(dim=2)
preds_str = []
preds_prob = []
labels = []
for i, pstr in enumerate(self.converter.decode(indexes)):
str_len = len(pstr)
if str_len == 0:
prob = 0
else:
prob = max_probs[i, :str_len].cumprod(dim=0)[-1]
preds_prob.append(prob)

if not sensitive:
pstr = pstr.lower()
if label is not None:
tmp = label[i].lower()

if character:
pstr = re.sub('[^{}]'.format(character), '', pstr)
if label is not None:
tmp = re.sub('[^{}]'.format(character), '', tmp)
labels.append(tmp)

preds_str.append(pstr)
if label is not None:
return preds_str, preds_prob, labels

return preds_str, preds_prob

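The postprocess change drops label handling (labels now stay with the caller); what remains is decoding plus a word-level confidence, the running product of per-step max probabilities over the decoded length. A small self-contained sketch with dummy tensors (the decoded strings are faked stand-ins for vedastr's converter output):

    import torch

    probs = torch.softmax(torch.randn(2, 5, 10), dim=2)  # (batch, time, classes)
    max_probs, indexes = probs.max(dim=2)

    decoded = ['abc', '']  # stand-in for self.converter.decode(indexes)
    for i, pstr in enumerate(decoded):
        # empty decodings get confidence 0, as in the runner
        prob = max_probs[i, :len(pstr)].cumprod(dim=0)[-1] if pstr else 0
        print(repr(pstr), float(prob))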
4 changes: 2 additions & 2 deletions vedastr/runners/test_runner.py
@@ -25,7 +25,7 @@ def test_batch(self, img, label):
else:
pred = self.model((img,))

pred, prob, label = self.postprocess(pred, self.postprocess_cfg, label)
pred, prob = self.postprocess(pred, self.postprocess_cfg)
self.metric.measure(pred, prob, label)
self.backup_metric.measure(pred, prob, label)

@@ -41,4 +41,4 @@ def __call__(self):
name, self.backup_metric.avg['acc']['true'], self.metric.avg['edit']
))
self.logger.info('Test, average acc %.4f, edit distance %s' % (self.metric.avg['acc']['true'],
self.metric.avg['edit']))
self.metric.avg['edit']))
5 changes: 3 additions & 2 deletions vedastr/runners/train_runner.py
@@ -209,10 +209,11 @@ def save_model(self,
meta=meta)

def resume(self, checkpoint, resume_optimizer=False,
resume_lr_scheduler=False, resume_meta=False,
resume_lr_scheduler=False, resume_meta=False, strict=True,
map_location='default'):
checkpoint = self.load_checkpoint(checkpoint,
map_location=map_location)
map_location=map_location,
strict=strict)

if resume_optimizer and 'optimizer' in checkpoint:
self.logger.info('Resume optimizer')
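With strict now forwarded from resume to load_checkpoint, a checkpoint whose state_dict only partially matches the current model can be resumed. A hedged usage sketch (the runner object and checkpoint path are illustrative):

    # strict=False lets load_checkpoint tolerate missing or unexpected keys,
    # e.g. after swapping a prediction head, instead of raising an error.
    runner.resume(
        'workdir/epoch_10.pth',
        resume_optimizer=True,
        resume_lr_scheduler=True,
        resume_meta=False,
        strict=False,
    )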
