Skip to content

Commit

Permalink
Merge pull request #8238 from enebo/regexp_timeout
Browse files Browse the repository at this point in the history
Implement regexp timeout
  • Loading branch information
enebo committed May 16, 2024
2 parents 1072499 + 3b79ce9 commit 0c13b12
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 32 deletions.
22 changes: 22 additions & 0 deletions core/src/main/java/org/jruby/Ruby.java
Original file line number Diff line number Diff line change
Expand Up @@ -1675,6 +1675,8 @@ private void initExceptions() {
ifAllowed("KeyError", (ruby) -> keyError = RubyKeyError.define(ruby, indexError));
ifAllowed("DomainError", (ruby) -> mathDomainError = RubyDomainError.define(ruby, argumentError, mathModule));

setRegexpTimeoutError(regexpClass.defineClassUnder("TimeoutError", getRegexpError(), RubyRegexpError::new));

RubyClass runtimeError = this.runtimeError;
ObjectAllocator runtimeErrorAllocator = runtimeError.getAllocator();

Expand Down Expand Up @@ -4930,6 +4932,22 @@ public RubyBinding getTopLevelBinding() {
return topLevelBinding;
}

public void setRubyTimeout(IRubyObject timeout) {
this.regexpTimeout = timeout;
}

public IRubyObject getRubyTimeout() {
return regexpTimeout;
}

public void setRegexpTimeoutError(RubyClass error) {
this.regexpTimeoutError = error;
}

public RubyClass getRegexpTimeoutError() {
return regexpTimeoutError;
}

static class FStringEqual {
RubyString string;
public boolean equals(Object other) {
Expand Down Expand Up @@ -5339,6 +5357,10 @@ public RubyClass getData() {
private volatile boolean objectSpaceEnabled;
private boolean siphashEnabled;

// Global timeout value. Nil == no timeout set.
private IRubyObject regexpTimeout;
private RubyClass regexpTimeoutError;

@Deprecated
private long globalState = 1;

Expand Down
133 changes: 102 additions & 31 deletions core/src/main/java/org/jruby/RubyRegexp.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.joni.Syntax;
import org.joni.WarnCallback;
import org.joni.exception.JOniException;
import org.joni.exception.TimeoutException;
import org.jruby.anno.JRubyClass;
import org.jruby.anno.JRubyMethod;
import org.jruby.common.IRubyWarnings.ID;
Expand All @@ -78,6 +79,7 @@

import static org.jruby.api.Convert.castToHash;
import static org.jruby.api.Error.typeError;
import static org.jruby.runtime.ThreadContext.resetCallInfo;
import static org.jruby.util.RubyStringBuilder.str;
import static org.jruby.util.StringSupport.CR_7BIT;
import static org.jruby.util.StringSupport.EMPTY_STRING_ARRAY;
Expand All @@ -90,6 +92,8 @@ public class RubyRegexp extends RubyObject implements ReOptions, EncodingCapable
private ByteList str = ByteList.EMPTY_BYTELIST;
private RegexpOptions options;

private IRubyObject timeout;

private static final ThreadLocal<IRubyObject[]> TL_HOLDER = ThreadLocal.withInitial(() -> new IRubyObject[1]);

public static final int ARG_ENCODING_FIXED = ReOptions.RE_FIXED;
Expand Down Expand Up @@ -213,6 +217,8 @@ public static RubyClass createRegexpClass(Ruby runtime) {
regexpClass.defineAnnotatedMethods(RubyRegexp.class);
regexpClass.getSingletonClass().defineAlias("compile", "new");

runtime.setRubyTimeout(runtime.getNil());

return regexpClass;
}

Expand All @@ -221,6 +227,8 @@ public static int matcherSearch(ThreadContext context, Matcher matcher, int star

try {
return context.getThread().executeRegexp(context, matcher, start, range, option, Matcher::searchInterruptible);
} catch (TimeoutException e) {
throw context.runtime.newRaiseException(context.runtime.getRegexpTimeoutError(), "regexp match timeout");
} catch (InterruptedException e) {
throw context.runtime.newInterruptedRegexpError("Regexp Interrupted");
}
Expand All @@ -231,6 +239,8 @@ public static int matcherMatch(ThreadContext context, Matcher matcher, int start

try {
return context.getThread().executeRegexp(context, matcher, start, range, option, Matcher::matchInterruptible);
} catch (TimeoutException e) {
throw context.runtime.newRaiseException(context.runtime.getRegexpTimeoutError(), "regexp match timeout");
} catch (InterruptedException e) {
throw context.runtime.newInterruptedRegexpError("Regexp Interrupted");
}
Expand Down Expand Up @@ -283,7 +293,7 @@ private RubyRegexp(Ruby runtime, ByteList str, RegexpOptions options) {
this(runtime);
assert str != null;

regexpInitialize(str, str.getEncoding(), options);
regexpInitialize(str, str.getEncoding(), options, null);
}

// used only by the compiler/interpreter (will set the literal flag)
Expand Down Expand Up @@ -358,7 +368,7 @@ static RubyRegexp newDummyRegexp(Ruby runtime, Regex regex) {
// MRI: rb_reg_new_str
public static RubyRegexp newRegexpFromStr(Ruby runtime, RubyString s, int options) {
RubyRegexp re = (RubyRegexp)runtime.getRegexp().allocate();
re.regexpInitializeString(s, RegexpOptions.fromJoniOptions(options));
re.regexpInitializeString(s, RegexpOptions.fromJoniOptions(options), null);
return re;
}

Expand Down Expand Up @@ -871,7 +881,7 @@ public IRubyObject initialize_copy(IRubyObject re) {
RubyRegexp regexp = (RubyRegexp)re;
regexp.check();

return regexpInitialize(regexp.str, regexp.str.getEncoding(), regexp.getOptions());
return regexpInitialize(regexp.str, regexp.str.getEncoding(), regexp.getOptions(), regexp.timeout);
}

private static int objectAsJoniOptions(IRubyObject arg) {
Expand All @@ -892,52 +902,66 @@ private static int objectAsJoniOptions(IRubyObject arg) {

@JRubyMethod(name = "initialize", visibility = Visibility.PRIVATE)
public IRubyObject initialize_m(IRubyObject arg) {
if (arg instanceof RubyRegexp) return initializeByRegexp((RubyRegexp)arg);
return regexpInitializeString(arg.convertToString(), new RegexpOptions());
return arg instanceof RubyRegexp regexp ?
initializeByRegexp(regexp, null) :
regexpInitializeString(arg.convertToString(), new RegexpOptions(), null);
}

@JRubyMethod(name = "initialize", visibility = Visibility.PRIVATE)
@JRubyMethod(name = "initialize", visibility = Visibility.PRIVATE, keywords = true)
public IRubyObject initialize_m(IRubyObject arg0, IRubyObject arg1) {
if (arg0 instanceof RubyRegexp && Options.PARSER_WARN_FLAGS_IGNORED.load()) {
metaClass.runtime.getWarnings().warn(ID.REGEXP_IGNORED_FLAGS, "flags ignored");
return initializeByRegexp((RubyRegexp)arg0);
ThreadContext context = getRuntime().getCurrentContext();
boolean keywords = (resetCallInfo(context) & ThreadContext.CALL_KEYWORD) != 0;


IRubyObject timeout;
RegexpOptions regexpOptions;
if (keywords) {
regexpOptions = new RegexpOptions();
timeout = timeoutFromArg(context, arg1);
if (arg0 instanceof RubyRegexp) return initializeByRegexp((RubyRegexp) arg0, timeout);
} else {
if (arg0 instanceof RubyRegexp && Options.PARSER_WARN_FLAGS_IGNORED.load()) {
metaClass.runtime.getWarnings().warn(ID.REGEXP_IGNORED_FLAGS, "flags ignored");
return initializeByRegexp((RubyRegexp)arg0, null);
}
regexpOptions = RegexpOptions.fromJoniOptions(objectAsJoniOptions(arg1));
timeout = null;
}

return regexpInitializeString(arg0.convertToString(),
RegexpOptions.fromJoniOptions(objectAsJoniOptions(arg1)));
return regexpInitializeString(arg0.convertToString(), regexpOptions, timeout);
}

@JRubyMethod(name = "initialize", visibility = Visibility.PRIVATE)
@JRubyMethod(name = "initialize", visibility = Visibility.PRIVATE, keywords = true)
public IRubyObject initialize_m(IRubyObject arg0, IRubyObject arg1, IRubyObject arg2) {
ThreadContext context = getRuntime().getCurrentContext();
boolean keywords = (resetCallInfo(context) & ThreadContext.CALL_KEYWORD) != 0;

if (arg0 instanceof RubyRegexp && Options.PARSER_WARN_FLAGS_IGNORED.load()) {
metaClass.runtime.getWarnings().warn(ID.REGEXP_IGNORED_FLAGS, "flags ignored");
return initializeByRegexp((RubyRegexp)arg0);
return initializeByRegexp((RubyRegexp)arg0, timeoutFromArg(context, arg2));
}

RegexpOptions newOptions = RegexpOptions.fromJoniOptions(objectAsJoniOptions(arg1));
if (!keywords) throw getRuntime().newArgumentError(3, 1, 2);

if (!arg2.isNil()) {
ByteList kcodeBytes = arg2.convertToString().getByteList();
if (kcodeBytes.getRealSize() > 0 && (kcodeBytes.get(0) == 'n' || kcodeBytes.get(0) == 'N')) {
newOptions.setEncodingNone(true);
return regexpInitialize(arg0.convertToString().getByteList(), ASCIIEncoding.INSTANCE, newOptions);
} else {
metaClass.runtime.getWarnings().warnDeprecated("encoding option is ignored - " + kcodeBytes);
}
}
return regexpInitializeString(arg0.convertToString(), newOptions);
return regexpInitializeString(arg0.convertToString(), newOptions, timeoutFromArg(context, arg2));
}

private IRubyObject initializeByRegexp(RubyRegexp regexp) {
private IRubyObject timeoutFromArg(ThreadContext context, IRubyObject arg) {
RubyHash kwargs = castToHash(context, arg);
return kwargs.fastARef(context.runtime.newSymbol("timeout"));
}

private IRubyObject initializeByRegexp(RubyRegexp regexp, IRubyObject timeoutProvided) {
// Clone and toggle flags since this is no longer a literal regular expression
// but it did come from one.
RegexpOptions newOptions = regexp.getOptions().clone();
newOptions.setLiteral(false);
return regexpInitialize(regexp.str, regexp.getEncoding(), newOptions);
return regexpInitialize(regexp.str, regexp.getEncoding(), newOptions, timeoutProvided != null ? timeoutProvided : regexp.timeout);
}

// rb_reg_initialize_str
private RubyRegexp regexpInitializeString(RubyString str, RegexpOptions options) {
private RubyRegexp regexpInitializeString(RubyString str, RegexpOptions options, IRubyObject timeout) {
if (isLiteral()) throw metaClass.runtime.newFrozenError(this);
ByteList bytes = str.getByteList();
Encoding enc = bytes.getEncoding();
Expand All @@ -949,13 +973,18 @@ private RubyRegexp regexpInitializeString(RubyString str, RegexpOptions options)
enc = ASCIIEncoding.INSTANCE;
}
}
return regexpInitialize(bytes, enc, options);
return regexpInitialize(bytes, enc, options, timeout);
}

// rb_reg_initialize
@Deprecated
public final RubyRegexp regexpInitialize(ByteList bytes, Encoding enc, RegexpOptions options) {
return regexpInitialize(bytes, enc, options, null);
}
// rb_reg_initialize
public final RubyRegexp regexpInitialize(ByteList bytes, Encoding enc, RegexpOptions options, IRubyObject timeout) {
Ruby runtime = metaClass.runtime;
this.options = options;
this.timeout = processTimeoutArg(runtime.getCurrentContext(), timeout);

checkFrozen();
// FIXME: Something unsets this bit, but we aren't...be more permissive until we figure this out
Expand Down Expand Up @@ -1167,8 +1196,9 @@ final RubyBoolean matchP(ThreadContext context, RubyString str, int pos) {
final Regex reg = preparePattern(str);
final ByteList strBL = str.getByteList();
final int beg = strBL.begin();
final long timeout = getRegexpTimeout(context);

Matcher matcher = reg.matcherNoRegion(strBL.unsafeBytes(), beg, beg + strBL.realSize());
Matcher matcher = reg.matcherNoRegion(strBL.unsafeBytes(), beg, beg + strBL.realSize(), timeout);

try {
final int result = matcherSearch(context, matcher, beg + pos, beg + strBL.realSize(), RE_OPTION_NONE);
Expand All @@ -1178,6 +1208,46 @@ final RubyBoolean matchP(ThreadContext context, RubyString str, int pos) {
}
}

@JRubyMethod(meta = true, name = "timeout=")
public static IRubyObject timeout_set(ThreadContext context, IRubyObject recv, IRubyObject timeout) {
context.runtime.setRubyTimeout(processTimeoutArg(context, timeout));
return timeout;
}

private static double MAX_TIMEOUT_VALUE = 18446744073.709553; // ((1<<64)-1) / 1000000000.0

private static IRubyObject processTimeoutArg(ThreadContext context, IRubyObject timeout) {
if (timeout == null) return null;
if (timeout.isNil()) return context.nil;

RubyFloat converted = timeout.convertToFloat();

if (converted.isInfinite() || converted.value > MAX_TIMEOUT_VALUE) converted = context.runtime.newFloat(MAX_TIMEOUT_VALUE);

if (converted.value <= 0) throw context.runtime.newArgumentError("invalid timeout: " + timeout);

return converted;
}

@JRubyMethod(meta = true, name = "timeout")
public static IRubyObject timeout(ThreadContext context, IRubyObject recv) {
return context.runtime.getRubyTimeout();
}

@JRubyMethod(name = "timeout")
public IRubyObject timeout(ThreadContext context) {
return timeout == null ? context.nil : timeout;
}

// float s in ns
private long getRegexpTimeout(ThreadContext context) {
IRubyObject timeout = this.timeout;
if (timeout != null && timeout.isNil()) return -1; // local override to ignore global timeout.
if (timeout == null) timeout = context.runtime.getRubyTimeout();

return timeout.isNil() ? -1 : (long) (timeout.convertToFloat().getDoubleValue() * 1_000_000_000);
}

/**
* MRI: rb_reg_search
*
Expand Down Expand Up @@ -1255,7 +1325,8 @@ public final int searchString(ThreadContext context, RubyString str, int pos, bo

if (!reverse) range += str.size();

final Matcher matcher = reg.matcher(strBL.unsafeBytes(), beg, beg + strBL.realSize());
final long timeout = getRegexpTimeout(context);
final Matcher matcher = reg.matcher(strBL.unsafeBytes(), beg, beg + strBL.realSize(), timeout);

try {
int result = matcherSearch(context, matcher, beg + pos, range, RE_OPTION_NONE);
Expand Down Expand Up @@ -1503,7 +1574,7 @@ private record RegexpArgs(RubyString string, int options, IRubyObject timeout) {
// MRI: reg_extract_args - This does not break the regexp into a String value since it will never used if the first
// argument is a Regexp. This also is true of MRI so I am not sure why they do the string part.
private static RegexpArgs extractRegexpArgs(ThreadContext context, IRubyObject[] args) {
int callInfo = ThreadContext.resetCallInfo(context);
int callInfo = resetCallInfo(context);
int length = args.length;

IRubyObject timeout = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ private IRubyObject unmarshalRegexp(MarshalState state) throws IOException {
byteList.setRealSize(dst - ptr);
}

regexp.regexpInitialize(byteList, byteList.getEncoding(), reOpts);
regexp.regexpInitialize(byteList, byteList.getEncoding(), reOpts, null);

if (ivarHolder != null) {
ivarHolder.getInstanceVariables().copyInstanceVariablesInto(regexp);
Expand Down

0 comments on commit 0c13b12

Please sign in to comment.