Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: brits on cuda #4

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion pypots/__version__.py
Expand Up @@ -21,4 +21,4 @@
# Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
# 'X.Y.dev0' is the canonical version of 'X.Y.dev'

version = '0.0.3'
version = '0.0.dev32'
9 changes: 5 additions & 4 deletions pypots/data/dataset_for_brits.py
Expand Up @@ -10,7 +10,7 @@
from pypots.data.base import BaseDataset


def parse_delta(missing_mask):
def parse_delta(missing_mask, device=None):
""" Generate time-gap (delta) matrix from missing masks.

Parameters
Expand All @@ -28,6 +28,7 @@ def parse_delta(missing_mask):
n_samples, n_steps, n_features = missing_mask.shape
device = missing_mask.device
delta_collector = []

for m_mask in missing_mask:
delta = []
for step in range(n_steps):
Expand All @@ -53,17 +54,17 @@ class DatasetForBRITS(BaseDataset):
Classification labels of according time-series samples.
"""

def __init__(self, X, y=None):
def __init__(self, X, y=None, device=None):
super().__init__(X, y)

# calculate all delta here.
# Training will take too much time if we put delta calculation in __getitem__().
forward_missing_mask = (~torch.isnan(X)).type(torch.float32)
forward_X = torch.nan_to_num(X)
forward_delta = parse_delta(forward_missing_mask)
forward_delta = parse_delta(forward_missing_mask, device)
backward_X = torch.flip(forward_X, dims=[1])
backward_missing_mask = torch.flip(forward_missing_mask, dims=[1])
backward_delta = parse_delta(backward_missing_mask)
backward_delta = parse_delta(backward_missing_mask, device)

self.data = {
'forward': {
Expand Down
9 changes: 5 additions & 4 deletions pypots/imputation/base.py
Expand Up @@ -82,16 +82,17 @@ def _train_model(self, training_loader, val_loader=None, val_X_intact=None, val_
try:
for epoch in range(self.epochs):
self.model.train()
epoch_train_loss_collector = []
epoch_train_loss_collector = torch.zeros(len(training_loader), device=self.device)
for idx, data in enumerate(training_loader):
inputs = self.assemble_input_data(data)
self.optimizer.zero_grad()
results = self.model.forward(inputs)
results['loss'].backward()
self.optimizer.step()
epoch_train_loss_collector.append(results['loss'].item())
with torch.no_grad():
epoch_train_loss_collector[idx] += results['loss']

mean_train_loss = np.mean(epoch_train_loss_collector) # mean training loss of the current epoch
mean_train_loss = torch.mean(epoch_train_loss_collector).item() # mean training loss of the current epoch
self.logger['training_loss'].append(mean_train_loss)

if val_loader is not None:
Expand Down Expand Up @@ -139,7 +140,7 @@ def _train_model(self, training_loader, val_loader=None, val_X_intact=None, val_
'Model will load the best parameters so far for testing. '
"If you don't want it, please try fit() again.")

if np.equal(self.best_loss, float('inf')):
if self.best_loss == float('inf'):
raise ValueError('Something is wrong. best_loss is Nan after training.')

print('Finished training.')
6 changes: 3 additions & 3 deletions pypots/imputation/brits.py
Expand Up @@ -491,15 +491,15 @@ def fit(self, train_X, val_X=None):
if val_X is not None:
val_X = self.check_input(self.n_steps, self.n_features, val_X)

training_set = DatasetForBRITS(train_X) # time_gaps is necessary for BRITS
training_set = DatasetForBRITS(train_X, device=self.device) # time_gaps is necessary for BRITS
training_loader = DataLoader(training_set, batch_size=self.batch_size, shuffle=True)

if val_X is None:
self._train_model(training_loader)
else:
val_X_intact, val_X, val_X_missing_mask, val_X_indicating_mask = mcar(val_X, 0.2)
val_X = masked_fill(val_X, 1 - val_X_missing_mask, torch.nan)
val_set = DatasetForBRITS(val_X)
val_set = DatasetForBRITS(val_X, device=self.device)
val_loader = DataLoader(val_set, batch_size=self.batch_size, shuffle=False)
self._train_model(training_loader, val_loader, val_X_intact, val_X_indicating_mask)

Expand Down Expand Up @@ -542,7 +542,7 @@ def assemble_input_data(self, data):
def impute(self, X):
X = self.check_input(self.n_steps, self.n_features, X)
self.model.eval() # set the model as eval status to freeze it.
test_set = DatasetForBRITS(X)
test_set = DatasetForBRITS(X, device=self.device)
test_loader = DataLoader(test_set, batch_size=self.batch_size, shuffle=False)
imputation_collector = []

Expand Down
14 changes: 7 additions & 7 deletions pypots/tests/test_imputation.py
Expand Up @@ -95,25 +95,25 @@ def setUp(self) -> None:
self.test_X_intact = DATA['test_X_intact']
self.test_X_indicating_mask = DATA['test_X_indicating_mask']
print('Running test cases for BRITS...')
self.brits = BRITS(DATA['n_steps'], DATA['n_features'], 256, epochs=EPOCH)
self.brits = BRITS(DATA['n_steps'], DATA['n_features'], 10, epochs=EPOCH)
self.brits.fit(self.train_X, self.val_X)

def test_parameters(self):
assert (hasattr(self.brits, 'model')
self.assertTrue(hasattr(self.brits, 'model')
and self.brits.model is not None)

assert (hasattr(self.brits, 'optimizer')
self.assertTrue(hasattr(self.brits, 'optimizer')
and self.brits.optimizer is not None)

assert hasattr(self.brits, 'best_loss')
self.assertTrue(hasattr(self.brits, 'best_loss'))
self.assertNotEqual(self.brits.best_loss, float('inf'))

assert (hasattr(self.brits, 'best_model_dict')
self.assertTrue(hasattr(self.brits, 'best_model_dict')
and self.brits.best_model_dict is not None)

def test_impute(self):
imputed_X = self.brits.impute(self.test_X)
assert not np.isnan(imputed_X).any(), 'Output still has missing values after running impute().'
self.assertFalse(np.isnan(imputed_X).any()), 'Output still has missing values after running impute().'
test_MAE = cal_mae(imputed_X, self.test_X_intact, self.test_X_indicating_mask)
print(f'BRITS test_MAE: {test_MAE}')

Expand All @@ -134,7 +134,7 @@ def test_parameters(self):

def test_impute(self):
test_X_imputed = self.locf.impute(self.test_X)
assert not np.isnan(test_X_imputed).any(), 'Output still has missing values after running impute().'
self.assertFalse(np.isnan(test_X_imputed).any()) # 'Output still has missing values after running impute().'
test_MAE = cal_mae(test_X_imputed, self.test_X_intact, self.test_X_indicating_mask)
print(f'LOCF test_MAE: {test_MAE}')

Expand Down