## 問題

https://twitter.com/e869120/status/1390798852299448322

## 解説

https://twitter.com/e869120/status/1391218516129312768

In [2]:
#r "nuget: FSharpPlus"

In [3]:
open System
open System.Numerics
open System.Collections.Generic
open FSharpPlus

In [4]:
/// A vertex of a rooted tree, as produced by `getTree`.
type Node = {
  /// Identifier of this vertex.
  Id: int
  /// Identifier of the parent vertex; `None` for the root.
  Parent: int option
  /// Identifiers of the vertices directly below this one.
  Children: int list
  /// Number of edges between the root and this vertex (the root has depth 0).
  Depth: int
}
/// A rooted tree: every vertex indexed by id, plus the root id and the tree height.
type Tree = {
  /// All vertices of the tree, keyed by `Node.Id`.
  NodeDict: IReadOnlyDictionary<int, Node>
  /// Identifier of the root vertex.
  Root: int
  /// Largest `Depth` found among the vertices.
  MaxDepth: int
} with
  /// Number of vertices stored in `NodeDict`.
  member x.Count = x.NodeDict.Count

In [5]:
/// Builds a rooted tree from an undirected edge list, walking level by level from the root.
///
/// rootId: id of the vertex to use as the root (depth 0).
/// edges:  undirected edges as (a, b) pairs; they are expected to form a tree.
/// Returns a `Tree` whose `NodeDict` holds every vertex reachable from `rootId`
/// with its parent, children and depth, and whose `MaxDepth` is the depth of the
/// deepest level. A root that appears in no edge yields a single-node tree.
let getTree rootId (edges: (int * int) seq) =
  // Undirected adjacency lists: every edge is registered in both directions.
  let adjacency =
    edges
    |> Seq.collect (fun (a, b) -> [a, b; b, a])
    |> Seq.groupBy fst
    |> Seq.map (fun (k, vs) -> k, vs |> Seq.map snd |> Seq.toList)
    |> dict
  // A vertex that occurs in no edge has no adjacency entry; treat it as isolated
  // instead of letting the dictionary indexer throw KeyNotFoundException.
  let neighbours nodeId =
    match adjacency.TryGetValue(nodeId) with
    | true, ns -> ns
    | _ -> []
  // Processes one whole depth level per call; `result` accumulates the finished levels.
  let rec loop (result: Dictionary<int, Node>) depth = function
    | [] -> { Root = rootId; NodeDict = result.AsReadOnly(); MaxDepth = depth - 1 }
    | nodeIdParents ->
      let nodes =
        nodeIdParents |> List.distinct |> List.map (fun (nodeId, parent) ->
          // Only earlier levels are in `result` at this point (the current level is
          // added below), so for a tree this filter removes just the parent.
          let children = neighbours nodeId |> List.filter (fun n -> not (result.ContainsKey(n)))
          { Id = nodeId; Parent = parent; Children = children; Depth = depth })
      nodes |> List.iter (fun n -> result.Add(n.Id, n))
      nodes |> List.collect (fun n -> n.Children |> List.map (fun c -> c, Some n.Id)) |> loop result (depth + 1)
  loop (Dictionary()) 0 [rootId, None]

In [6]:
/// Builds the binary-lifting (doubling) ancestor table of `tree`.
///
/// Returns `(table, len)`:
///   table[(i, v)] is the vertex reached from `v` by following `Parent` 2^i times,
///                 where the root counts as its own parent;
///   len           is the smallest shift `i` for which `tree.MaxDepth >>> i` is not positive.
/// Level 0 is always stored; levels 1 .. len - 1 are stored on top of it.
let calcDoubling tree =
  // First shift amount that leaves nothing of MaxDepth.
  let rec bitLength shift =
    if (tree.MaxDepth >>> shift) > 0 then bitLength (shift + 1) else shift
  let levels = bitLength 0
  let table = Dictionary<int * int, int>()
  // Level 0: the direct parent; the root has none and points at itself.
  for KeyValue(nodeId, node) in tree.NodeDict do
    let parent =
      match node.Parent with
      | Some p -> p
      | None -> tree.Root
    table.Add((0, nodeId), parent)
  // Level i: two consecutive jumps of level i - 1.
  for level in 1 .. levels - 1 do
    for nodeId in tree.NodeDict.Keys do
      let halfway = table[level - 1, nodeId]
      table.Add((level, nodeId), table[level - 1, halfway])
  table.AsReadOnly(), levels

In [7]:
/// Finds the lowest common ancestor of vertices `x` and `y`.
///
/// doubling: ancestor table indexed by (level, vertex), where level i means a jump of 2^i parents.
/// len:      number of table levels scanned in the descent (levels len - 1 down to 0).
/// tree:     the tree the table was built from; only vertex depths are read from it.
/// Returns the id of the deepest vertex that is an ancestor of both `x` and `y`
/// (a vertex counts as its own ancestor).
let calcLowestCommonAncestor (doubling: IReadOnlyDictionary<int * int, int>) len tree x y =
  // Moves `nodeId` up by `generation` parents, one table jump per set bit of `generation`.
  let climb generation nodeId =
    // Positions of the set bits; consing while scanning upward leaves the highest bit first.
    let rec setBits remaining level acc =
      if remaining = 0 then acc
      else setBits (remaining >>> 1) (level + 1) (if remaining &&& 1 = 0 then acc else level :: acc)
    let mutable current = nodeId
    for level in setBits generation 0 [] do
      current <- doubling[level, current]
    current
  let depthOf nodeId = tree.NodeDict[nodeId].Depth
  let depthDiff = depthOf x - depthOf y
  // Bring the deeper vertex up to the depth of the shallower one.
  let u, v =
    if depthDiff > 0 then climb depthDiff x, y
    elif depthDiff < 0 then x, climb (abs depthDiff) y
    else x, y
  if u = v then u
  else
    let mutable a = u
    let mutable b = v
    // Take the largest jumps that keep the two vertices distinct; they end up
    // directly below the common ancestor.
    for level = len - 1 downto 0 do
      let upA = doubling[level, a]
      let upB = doubling[level, b]
      if upA <> upB then
        a <- upA
        b <- upB
    doubling[0, a]

In [8]:
/// Number of edges on the path between vertices `x` and `y`, computed as
/// depth(x) + depth(y) - 2 * depth(lowest common ancestor of x and y).
let calcDistance doubling len tree x y =
  let depthOf nodeId = tree.NodeDict[nodeId].Depth
  let depthX = depthOf x
  let depthY = depthOf y
  let meetingPoint = calcLowestCommonAncestor doubling len tree x y
  depthX + depthY - 2 * depthOf meetingPoint

In [9]:
/// For each query, computes the minimum number of tree edges needed to connect
/// all of the queried vertices.
///
/// ABs: undirected edges (a, b) of the tree; vertex 1 is used as the root.
/// Qs:  the queries, each a collection of vertex ids.
/// Returns one value per query, in query order. An empty query yields 0.
let solve ABs (Qs: int seq seq) =
  let tree = getTree 1 ABs
  let doubling, len = calcDoubling tree
  // Pre-order visiting rank of every vertex (children taken in `Children` order).
  // Done with an explicit stack rather than recursion so that deep, path-like
  // trees neither exhaust the call stack nor pay for repeated list appends.
  let nodeIndexDict =
    let rank = Dictionary<int, int>()
    let pending = Stack<int>()
    pending.Push(tree.Root)
    while pending.Count > 0 do
      let nodeId = pending.Pop()
      rank.Add(nodeId, rank.Count)
      let children = tree.NodeDict[nodeId].Children
      // Push in reverse so that the first child is popped, and hence ranked, first.
      for child in List.rev children do
        pending.Push(child)
    rank
  Qs |> Seq.map (fun vs ->
    match vs |> Seq.sortBy (fun q -> nodeIndexDict[q]) |> Seq.toList with
    | [] -> 0 // nothing to connect
    | first :: _ as nodes ->
      // Walking the vertices in pre-order and returning to the first one crosses
      // every edge of the connecting subtree exactly twice, hence the halving.
      let closedWalk =
        nodes @ [first]
        |> Seq.pairwise
        |> Seq.sumBy (fun (a, b) -> calcDistance doubling len tree a b)
      closedWalk / 2)
  |> Seq.toList

In [10]:
// Sample run: the first list holds the tree edges, the second holds the queries
// (one inner list of vertex ids per query).
// Recorded cell output for this call: 1, 3, 4, 4, 5.
solve
  [
    1, 2
    2, 3
    3, 4
    1, 5
    3, 6
  ]
  [
    [1; 2]
    [1; 3; 5]
    [2; 3; 4; 5]
    [1; 2; 3; 5; 6]
    [1; 2; 3; 4; 5; 6]
  ]

index,value
0,1
1,3
2,4
3,4
4,5


In [11]:
// Sample run on the same edge list with two-vertex queries, so each answer is
// the distance between the two vertices.
// Recorded cell output for this call: 1, 1, 2, 1, 2.
solve
  [
    1, 2
    2, 3
    3, 4
    1, 5
    3, 6
  ]
  [
    [1; 2]
    [3; 4]
    [4; 6]
    [1; 5]
    [2; 5]
  ]

index,value
0,1
1,1
2,2
3,1
4,2


In [14]:
// Sample run on the same edge list with three-vertex queries.
// Recorded cell output for this call: 2, 2, 3, 4, 5.
solve
  [
    1, 2
    2, 3
    3, 4
    1, 5
    3, 6    
  ]
  [
    [1; 2; 3]
    [1; 2; 5]
    [1; 3; 6]
    [3; 4; 5]
    [4; 5; 6]
  ]

index,value
0,2
1,2
2,3
3,4
4,5
