Added CPU support #23

Open · wants to merge 1 commit into master
src/mlm/scorers.py: 26 additions & 8 deletions
@@ -558,13 +558,23 @@ def __init__(self, *args, **kwargs):
             raise ValueError("Language was not set but this model uses language embeddings!")
 
         ### PyTorch-based
-        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self._device = torch.device("cuda:0" if (torch.cuda.is_available() and self._ctxs != [mx.cpu()]) else "cpu")
         torch.manual_seed(0)
         torch.cuda.manual_seed_all(0)
         # TODO: This does not restrict to specific GPUs however, use CUDA_VISIBLE_DEVICES?
         # TODO: It also unnecessarily locks the GPUs to each other
         self._model.to(self._device)
-        self._model = torch.nn.DataParallel(self._model, device_ids=[0])
+
+        # Use DataParallel to run on multiple GPUs if available
+        if self._device.type == "cuda" and torch.cuda.device_count() > 1:
+            logging.info("Using {} GPUs!".format(torch.cuda.device_count()))
+            self._model = torch.nn.DataParallel(self._model)
+        else:
+            if self._device.type == "cuda":
+                logging.info("Using 1 GPU!")
+            else:
+                logging.info("Using CPU!")
 
         self._model.eval()
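The device choice now keys off both CUDA availability and the MXNet context list, so constructing the scorer with [mx.cpu()] forces the PyTorch side onto the CPU even on a CUDA machine, and DataParallel is only applied on multi-GPU runs. A minimal standalone sketch of the same selection logic, with a hypothetical ctxs parameter standing in for self._ctxs:

import logging
import mxnet as mx
import torch

def select_device(ctxs):
    # Fall back to CPU when CUDA is absent or the caller explicitly asked for mx.cpu()
    use_cuda = torch.cuda.is_available() and ctxs != [mx.cpu()]
    device = torch.device("cuda:0" if use_cuda else "cpu")
    if device.type == "cuda" and torch.cuda.device_count() > 1:
        logging.info("Using %d GPUs!", torch.cuda.device_count())
    elif device.type == "cuda":
        logging.info("Using 1 GPU!")
    else:
        logging.info("Using CPU!")
    return device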


@@ -697,7 +707,7 @@ def sum_accumulated_scores():
 
         for ctx_idx, (sent_idxs, token_ids, valid_length, masked_positions, token_masked_ids, normalization) in enumerate((batch,)):
 
-            ctx = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            ctx = self._device
            batch_size += sent_idxs.shape[0]
 
             # TODO: Super inefficient where we go from MXNet to NumPy to PyTorch
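Reusing self._device here is what makes the CPU setting stick at scoring time: re-deriving the context from torch.cuda.is_available() would move batches to the GPU even when the model was deliberately placed on the CPU. A toy illustration of the mismatch (hypothetical model and tensors, not from the PR):

import torch

device = torch.device("cpu")              # scorer was forced onto the CPU
model = torch.nn.Linear(4, 4).to(device)

# Old behavior: re-derive the context from the hardware.
ctx = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.zeros(2, 4).to(ctx)
# On a CUDA machine, model(x) would raise a device-mismatch RuntimeError here.

# New behavior: reuse the stored device so inputs follow the model.
x = torch.zeros(2, 4).to(device)
y = model(x)                              # runs on the CPU as requested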
@@ -715,17 +725,25 @@ def sum_accumulated_scores():
             token_masked_ids = token_masked_ids.to(ctx)
 
             split_size = token_ids.shape[0]
+
+            # Check if we are using DataParallel
+            if hasattr(self._model, 'module'):
+                # Using DataParallel, so we need to access the underlying model
+                model = self._model.module
+            else:
+                # Not using DataParallel
+                model = self._model
 
-            if isinstance(self._model.module, AlbertForMaskedLMOptimized) or \
-                isinstance(self._model.module, BertForMaskedLMOptimized) or \
-                isinstance(self._model.module, DistilBertForMaskedLMOptimized):
+            if isinstance(model, AlbertForMaskedLMOptimized) or \
+                isinstance(model, BertForMaskedLMOptimized) or \
+                isinstance(model, DistilBertForMaskedLMOptimized):
                 # Because BERT does not take a length parameter
                 alen = torch.arange(token_ids.shape[1], dtype=torch.long)
                 alen = alen.to(ctx)
                 mask = alen < valid_length[:, None]
                 out = self._model(input_ids=token_ids, attention_mask=mask, select_positions=masked_positions)
                 out = out[0].squeeze()
-            elif isinstance(self._model.module, transformers.BertForMaskedLM):
+            elif isinstance(model, transformers.BertForMaskedLM):
                 # Because BERT does not take a length parameter
                 alen = torch.arange(token_ids.shape[1], dtype=torch.long)
                 alen = alen.to(ctx)
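Two patterns in this hunk generalize. Since the model is only wrapped in DataParallel on multi-GPU runs, isinstance checks must look through the wrapper via .module; and because BERT-style models take a boolean attention mask rather than a length, the mask is built by broadcasting a position range against valid_length. A self-contained sketch of both (toy shapes, hypothetical names):

import torch

def unwrap(model):
    # DataParallel keeps the wrapped model at .module; plain models have no such attribute
    return model.module if hasattr(model, "module") else model

# Build a (batch, seq_len) boolean mask from per-sentence valid lengths
token_ids = torch.zeros(2, 5, dtype=torch.long)
valid_length = torch.tensor([3, 5])
alen = torch.arange(token_ids.shape[1], dtype=torch.long)
mask = alen < valid_length[:, None]
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])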
@@ -734,7 +752,7 @@ def sum_accumulated_scores():
                 # out[0] contains the distributions for the masked tokens (batch_size, sequence_length, config.vocab_size)
                 # Reindex to only get the distributions at the masked positions (batch_size, config.vocab_size)
                 out = out[0][list(range(split_size)),masked_positions.reshape(-1),:]
-            elif isinstance(self._model.module, transformers.XLMWithLMHeadModel):
+            elif isinstance(model, transformers.XLMWithLMHeadModel):
                 if self._lang is not None and self._tokenizer.lang2id is not None:
                     langs = torch.ones_like(token_ids)*self._tokenizer.lang2id[self._lang]
                 else:
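The XLM branch uses the same unwrapped model for its type check and, when the tokenizer exposes a lang2id mapping, tags every token position with the language id. A minimal sketch of that construction (hypothetical lang2id dict standing in for self._tokenizer.lang2id; the else branch is truncated in the diff and assumed to leave langs as None):

import torch

lang2id = {"en": 0, "fr": 4}              # stand-in for self._tokenizer.lang2id
lang = "fr"

token_ids = torch.zeros(2, 5, dtype=torch.long)
if lang is not None and lang2id is not None:
    # Same shape as token_ids, every position filled with the language id
    langs = torch.ones_like(token_ids) * lang2id[lang]
else:
    langs = None    # assumed: XLM then falls back to its default language embeddings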