Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup serialization of TaskReportMap #16217

Merged
merged 13 commits into from
Apr 1, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@

import javax.annotation.Nullable;
import java.util.List;
import java.util.Map;

/**
* Interface for the controller of a multi-stage query.
Expand Down Expand Up @@ -123,6 +122,6 @@ void resultsComplete(
List<String> getTaskIds();

@Nullable
Map<String, TaskReport> liveReports();
TaskReport.ReportMap liveReports();

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
import org.apache.druid.java.util.emitter.service.ServiceEmitter;
import org.apache.druid.server.DruidNode;

import java.util.Map;

/**
* Context used by multi-stage query controllers.
*
Expand Down Expand Up @@ -80,5 +78,5 @@ public interface ControllerContext
/**
* Writes controller task report.
*/
void writeReports(String controllerTaskId, Map<String, TaskReport> reports);
void writeReports(String controllerTaskId, TaskReport.ReportMap reports);
}
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ public void resultsComplete(

@Override
@Nullable
public Map<String, TaskReport> liveReports()
public TaskReport.ReportMap liveReports()
{
final QueryDefinition queryDef = queryDefRef.get();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@
import org.apache.druid.segment.realtime.firehose.ChatHandler;
import org.apache.druid.server.DruidNode;

import java.util.Map;

/**
* Implementation for {@link ControllerContext} required to run multi-stage queries as indexing tasks.
*/
Expand Down Expand Up @@ -126,7 +124,7 @@ public WorkerManagerClient workerManager()
}

@Override
public void writeReports(String controllerTaskId, Map<String, TaskReport> reports)
public void writeReports(String controllerTaskId, TaskReport.ReportMap reports)
{
toolbox.getTaskReportFileWriter().write(controllerTaskId, reports);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.util.List;
import java.util.Map;

public class ControllerChatHandler implements ChatHandler
{
Expand Down Expand Up @@ -189,7 +188,7 @@ public Response httpGetTaskList(@Context final HttpServletRequest req)
public Response httpGetLiveReports(@Context final HttpServletRequest req)
{
ChatHandlers.authorizationCheck(req, Action.WRITE, task.getDataSource(), toolbox.getAuthorizerMapper());
final Map<String, TaskReport> reports = controller.liveReports();
final TaskReport.ReportMap reports = controller.liveReports();
if (reports == null) {
return Response.status(Response.Status.NOT_FOUND).build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
import javax.ws.rs.core.Response;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;

public class WorkerChatHandlerTest
{
Expand Down Expand Up @@ -88,7 +87,7 @@ public void setUp()
new TaskReportFileWriter()
{
@Override
public void write(String taskId, Map<String, TaskReport> reports)
public void write(String taskId, TaskReport.ReportMap reports)
{

}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.druid.msq.indexing.client;

import org.apache.druid.indexing.common.KillTaskReport;
import org.apache.druid.indexing.common.TaskReport;
import org.apache.druid.indexing.common.TaskToolbox;
import org.apache.druid.msq.exec.Controller;
import org.apache.druid.msq.indexing.MSQControllerTask;
import org.apache.druid.server.security.AuthorizerMapper;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;

import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.core.Response;

/**
 * Unit test for the live-reports endpoint exposed by {@link ControllerChatHandler}.
 */
public class ControllerChatHandlerTest
{
  /**
   * Checks that {@code httpGetLiveReports} responds with status 200 and returns, as the
   * response entity, a report map equal to the one supplied by the controller.
   */
  @Test
  public void testHttpGetLiveReports()
  {
    // The report map that the mocked controller serves as its live reports.
    final TaskReport.ReportMap expectedReports = new TaskReport.ReportMap();
    expectedReports.put(
        "killUnusedSegments",
        new KillTaskReport("kill_1", new KillTaskReport.Stats(1, 2, 3))
    );

    final MSQControllerTask controllerTask = Mockito.mock(MSQControllerTask.class);
    Mockito.when(controllerTask.getDataSource()).thenReturn("wiki");

    final Controller controller = Mockito.mock(Controller.class);
    Mockito.when(controller.task()).thenReturn(controllerTask);
    Mockito.when(controller.liveReports()).thenReturn(expectedReports);

    final TaskToolbox toolbox = Mockito.mock(TaskToolbox.class);
    Mockito.when(toolbox.getAuthorizerMapper()).thenReturn(new AuthorizerMapper(null));

    // Any attribute lookup on the request yields "allow-all".
    // NOTE(review): presumably this is what lets the handler's authorization check pass - confirm.
    final HttpServletRequest request = Mockito.mock(HttpServletRequest.class);
    Mockito.when(request.getAttribute(ArgumentMatchers.anyString())).thenReturn("allow-all");

    final ControllerChatHandler handler = new ControllerChatHandler(toolbox, controller);
    final Response liveReportsResponse = handler.httpGetLiveReports(request);

    Assert.assertEquals(200, liveReportsResponse.getStatus());
    Assert.assertEquals(expectedReports, liveReportsResponse.getEntity());
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class MSQTaskReportTest
{
Expand Down Expand Up @@ -242,9 +241,9 @@ public void testWriteTaskReport() throws Exception
writer.setObjectMapper(mapper);
writer.write(TASK_ID, TaskReport.buildTaskReports(report));

final Map<String, TaskReport> reportMap = mapper.readValue(
final TaskReport.ReportMap reportMap = mapper.readValue(
reportFile,
new TypeReference<Map<String, TaskReport>>()
new TypeReference<TaskReport.ReportMap>()
{
}
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public class MSQTestControllerContext implements ControllerContext
private final ServiceEmitter emitter = new NoopServiceEmitter();

private Controller controller;
private Map<String, TaskReport> report = null;
private TaskReport.ReportMap report = null;
private final WorkerMemoryParameters workerMemoryParameters;

public MSQTestControllerContext(
Expand Down Expand Up @@ -273,14 +273,14 @@ public WorkerClient taskClientFor(Controller controller)
}

@Override
public void writeReports(String controllerTaskId, Map<String, TaskReport> taskReport)
public void writeReports(String controllerTaskId, TaskReport.ReportMap taskReport)
{
if (controller != null && controller.id().equals(controllerTaskId)) {
report = taskReport;
}
}

public Map<String, TaskReport> getAllReports()
public TaskReport.ReportMap getAllReports()
{
return report;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public class MSQTestOverlordServiceClient extends NoopOverlordClient
private final WorkerMemoryParameters workerMemoryParameters;
private final List<ImmutableSegmentLoadInfo> loadedSegmentMetadata;
private final Map<String, Controller> inMemoryControllers = new HashMap<>();
private final Map<String, Map<String, TaskReport>> reports = new HashMap<>();
private final Map<String, TaskReport.ReportMap> reports = new HashMap<>();
private final Map<String, MSQControllerTask> inMemoryControllerTask = new HashMap<>();
private final Map<String, TaskStatus> inMemoryTaskStatus = new HashMap<>();

Expand Down Expand Up @@ -171,7 +171,7 @@ public ListenableFuture<TaskStatusResponse> taskStatus(String taskId)

// hooks to pull stuff out for testing
@Nullable
public Map<String, TaskReport> getReportForTask(String id)
public TaskReport.ReportMap getReportForTask(String id)
{
return reports.get(id);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ public FrameContext frameContext(QueryDefinition queryDef, int stageNumber)
final TaskReportFileWriter reportFileWriter = new TaskReportFileWriter()
{
@Override
public void write(String taskId, Map<String, TaskReport> reports)
public void write(String taskId, TaskReport.ReportMap reports)
{

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,25 @@ public Object getPayload()
return stats;
}

/**
 * Two reports are equal when they are of exactly the same class and carry
 * equal {@code taskId} and {@code stats} values.
 */
@Override
public boolean equals(Object o)
{
  if (o == this) {
    return true;
  }
  if (o == null) {
    return false;
  }
  // Exact class match: an instance of a different class (including a subclass) is never equal.
  if (o.getClass() != getClass()) {
    return false;
  }
  final KillTaskReport other = (KillTaskReport) o;
  if (!Objects.equals(taskId, other.taskId)) {
    return false;
  }
  return Objects.equals(stats, other.stats);
}

/**
 * Hash derived from {@code taskId} and {@code stats}, the same fields compared in
 * {@link #equals(Object)}.
 */
@Override
public int hashCode()
{
  // Same 31-based accumulation (seeded with 1) that Objects.hash(taskId, stats) performs.
  int result = 1;
  result = 31 * result + Objects.hashCode(taskId);
  result = 31 * result + Objects.hashCode(stats);
  return result;
}

public static class Stats
{
private final int numSegmentsKilled;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class MultipleFileTaskReportFileWriter implements TaskReportFileWriter
private ObjectMapper objectMapper;

@Override
public void write(String taskId, Map<String, TaskReport> reports)
public void write(String taskId, TaskReport.ReportMap reports)
{
final File reportsFile = taskReportFiles.get(taskId);
if (reportsFile == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,13 @@

package org.apache.druid.indexing.common;

import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializerProvider;
import org.apache.druid.java.util.common.FileUtils;
import org.apache.druid.java.util.common.jackson.JacksonUtils;
import org.apache.druid.java.util.common.logger.Logger;

import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStream;
import java.util.Map;

public class SingleFileTaskReportFileWriter implements TaskReportFileWriter
{
Expand All @@ -44,7 +40,7 @@ public SingleFileTaskReportFileWriter(File reportsFile)
}

@Override
public void write(String taskId, Map<String, TaskReport> reports)
public void write(String taskId, TaskReport.ReportMap reports)
{
try {
final File reportsFileParent = reportsFile.getParentFile();
Expand All @@ -70,20 +66,9 @@ public void setObjectMapper(ObjectMapper objectMapper)
public static void writeReportToStream(
final ObjectMapper objectMapper,
final OutputStream outputStream,
final Map<String, TaskReport> reports
final TaskReport.ReportMap reports
) throws Exception
{
final SerializerProvider serializers = objectMapper.getSerializerProviderInstance();

try (final JsonGenerator jg = objectMapper.getFactory().createGenerator(outputStream)) {
jg.writeStartObject();

for (final Map.Entry<String, TaskReport> entry : reports.entrySet()) {
jg.writeFieldName(entry.getKey());
JacksonUtils.writeObjectUsingSerializerProvider(jg, serializers, entry.getValue());
}

jg.writeEndObject();
}
objectMapper.writeValue(outputStream, reports);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,20 @@

import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.google.common.base.Optional;

import java.util.LinkedHashMap;
import java.util.Map;

/**
* TaskReport objects contain additional information about an indexing task, such as row statistics, errors, and
* published segments. They are kept in deep storage along with task logs.
*/
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type")
@JsonSubTypes(value = {
@JsonSubTypes.Type(name = "ingestionStatsAndErrors", value = IngestionStatsAndErrorsTaskReport.class),
@JsonSubTypes.Type(
name = IngestionStatsAndErrorsTaskReport.REPORT_KEY,
value = IngestionStatsAndErrorsTaskReport.class
),
@JsonSubTypes.Type(name = KillTaskReport.REPORT_KEY, value = KillTaskReport.class)
})
public interface TaskReport
Expand All @@ -48,13 +51,29 @@ public interface TaskReport
/**
* Returns an order-preserving map that is suitable for passing into {@link TaskReportFileWriter#write}.
*/
static Map<String, TaskReport> buildTaskReports(TaskReport... taskReports)
static ReportMap buildTaskReports(TaskReport... taskReports)
{
// Use LinkedHashMap to preserve order of the reports.
Map<String, TaskReport> taskReportMap = new LinkedHashMap<>();
ReportMap taskReportMap = new ReportMap();
for (TaskReport taskReport : taskReports) {
taskReportMap.put(taskReport.getReportKey(), taskReport);
}
return taskReportMap;
}

/**
* Represents an ordered map from report key to a TaskReport that is compatible
* for writing out reports to files or serving over HTTP.
* <p>
* This class is needed for Jackson serde to work correctly. Without this class,
* a TaskReport is serialized without the type information and cannot be
* deserialized back into a concrete implementation.
*/
class ReportMap extends LinkedHashMap<String, TaskReport>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any tests that verify the reports are indeed ordered since we rely on a LinkedHashMap? Just looking at the callers of buildTaskReports(), I don't seem to find any.

Copy link
Contributor Author

@kfaraz kfaraz Apr 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, but I can add a test to verify the order. That said, I don't see any actual task writing a report map that contains multiple entries. I'm also not sure why the order was considered important in the first place; it's JSON anyway.

{
/**
 * Looks up the report stored under the given key.
 *
 * @param reportKey key of the report to look up
 * @param <T>       report type expected by the caller; the cast is unchecked, so it is
 *                  not verified here
 * @return an Optional holding the mapped report, or an absent Optional when the map
 *         yields {@code null} for the key
 */
@SuppressWarnings("unchecked")
public <T extends TaskReport> Optional<T> findReport(String reportKey)
{
  final T report = (T) get(reportKey);
  return Optional.fromNullable(report);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,9 @@

import com.fasterxml.jackson.databind.ObjectMapper;

import java.util.Map;

public interface TaskReportFileWriter
{
void write(String taskId, Map<String, TaskReport> reports);
void write(String taskId, TaskReport.ReportMap reports);

void setObjectMapper(ObjectMapper objectMapper);
}