diff --git a/Research/QuantBook.cs b/Research/QuantBook.cs index d7577cad7e7e..feaff26c4c29 100644 --- a/Research/QuantBook.cs +++ b/Research/QuantBook.cs @@ -940,6 +940,7 @@ private IEnumerable RunUniverseSelection(Universe universe, var history = History(universe, start, endDate); HashSet filteredSymbols = null; + SecurityType? filteredSecurityType = null; Func processDataPoint = dataPoint => { var utcTime = dataPoint.EndTime.ConvertToUtc(universe.Configuration.ExchangeTimeZone); @@ -947,8 +948,9 @@ private IEnumerable RunUniverseSelection(Universe universe, if (!ReferenceEquals(selection, Universe.Unchanged)) { filteredSymbols = selection.ToHashSet(); + filteredSecurityType = filteredSymbols.FirstOrDefault()?.ID.SecurityType; } - dataPoint.Data = dataPoint.Data.Where(x => filteredSymbols == null || filteredSymbols.Contains(x.Symbol)).ToList(); + dataPoint.Data = FilterUniverseData(dataPoint.Data, filteredSymbols, filteredSecurityType); return dataPoint; }; @@ -957,6 +959,19 @@ private IEnumerable RunUniverseSelection(Universe universe, return PerformSelection(history, processDataPoint, getTime, start, endDate, dateRule); } + private static List FilterUniverseData(List data, HashSet filteredSymbols, SecurityType? filteredSecurityType) + { + return data.Where(x => + filteredSymbols == null || + filteredSymbols.Contains(x.Symbol) || + (filteredSecurityType.HasValue && x.Symbol.SecurityType != filteredSecurityType.Value && + filteredSymbols.Any(s => + CurrencyPairUtil.TryDecomposeCurrencyPair(s, out var baseCurrency, out var quoteCurrency) && + (x.Symbol.Value.Equals(baseCurrency, StringComparison.OrdinalIgnoreCase) || + x.Symbol.Value.Equals(quoteCurrency, StringComparison.OrdinalIgnoreCase)))) + ).ToList(); + } + /// /// Converts a pandas.Series into a /// diff --git a/Tests/Research/QuantBookSelectionTests.cs b/Tests/Research/QuantBookSelectionTests.cs index f8f82407fdeb..d3fc2865db4a 100644 --- a/Tests/Research/QuantBookSelectionTests.cs +++ b/Tests/Research/QuantBookSelectionTests.cs @@ -15,13 +15,17 @@ using System; using System.Linq; +using System.Reflection; using Python.Runtime; using NUnit.Framework; using QuantConnect.Research; using System.Collections.Generic; +using QuantConnect.Data; using QuantConnect.Data.Fundamental; +using QuantConnect.Data.Market; using QuantConnect.Data.UniverseSelection; using QuantConnect.Scheduling; +using QuantConnect.Util; namespace QuantConnect.Tests.Research { @@ -824,6 +828,25 @@ def getUniverseHistory(self, qb, start, end, symbol, flatten): } } + [Test] + public void FilterUniverseDataFallsBackToCurrencyPairMatchingWhenSecurityTypesDiffer() + { + var eurBase = new Symbol(SecurityIdentifier.GenerateBase(typeof(Fundamental), "EUR", Market.USA), "EUR"); + var gbpBase = new Symbol(SecurityIdentifier.GenerateBase(typeof(Fundamental), "GBP", Market.USA), "GBP"); + var eurUsd = Symbol.Create("EURUSD", SecurityType.Forex, Market.Oanda); + + // selection returns Forex — security type differs from Base data symbols + var filteredSymbols = new HashSet { eurUsd }; + var data = new List { new Tick { Symbol = eurBase }, new Tick { Symbol = gbpBase } }; + + var filterMethod = typeof(QuantBook).GetMethod("FilterUniverseData", BindingFlags.NonPublic | BindingFlags.Static); + var result = (List)filterMethod.Invoke(null, new object[] { data, filteredSymbols, (SecurityType?)eurUsd.SecurityType }); + + // EUR matches EURUSD base currency, GBP matches neither + Assert.AreEqual(1, result.Count); + Assert.AreEqual(eurBase, result[0].Symbol); + } + [Test] public void PerformSelectionDoesNotSkipDataPointWhenPreviousDataPointIsYielded() { @@ -875,7 +898,7 @@ private static string GetBaseImplementation(int expectedCount, string identation .Replace("{identation}", identation, StringComparison.InvariantCulture); } - private class QuantBookTestClass: QuantBook + private class QuantBookTestClass : QuantBook { public static IEnumerable PerformSelection(IEnumerable history, DateTime start, DateTime end, IDateRule dateRule) {