In [1]:
versioninfo()

Julia Version 1.8.5
Commit 17cfb8e65ea (2023-01-08 06:45 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 12 × Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, skylake)
  Threads: 1 on 12 virtual cores


## 8-6. 生成関数

### 8-6-1. 生成関数 (Generated Function) とは

#### コード8-46. 簡単な生成関数の例（`printwithtype()`）

In [2]:
# Print the argument followed by its type.
# Inside an `@generated` body the name `x` is bound to the *type* of the
# argument, so `$argtype` splices that type into the returned expression as a
# constant, while the bare `x` inside the quote refers to the runtime value.
@generated function printwithtype(x)
    argtype = x
    return quote
        print(x, " is a type of ", $argtype)
    end
end

printwithtype (generic function with 1 method)

#### コード8-47. `printwithtype()` 関数の動作確認

In [3]:
printwithtype(1)

1 is a type of Int64

In [4]:
printwithtype("文字列")

文字列 is a type of String

In [5]:
printwithtype(rand())

0.4429312826527251 is a type of Float64

In [6]:
printwithtype(Int)

Int64 is a type of Type{Int64}

In [7]:
@code_lowered printwithtype(1)

CodeInfo(
    @ In[2]:1 within `printwithtype`
   ┌ @ In[2]:1 within `macro expansion`
1 ─│ %1 = Main.print(x, " is a type of ", Int64)
└──│      return %1
   └
)

#### コード8-48. 生成関数の例(2)（`grand1to10()`）

In [8]:
# Return an integer drawn from 1:10 by the generator.
# `rand` runs while the method body is being generated, not on every call, so
# the drawn number itself becomes the whole body; the recorded calls show the
# same value repeating for a given combination of argument types.
# `_DUMMY...` only serves to vary the argument types; its values are unused.
@generated function grand1to10(_DUMMY...)
    drawn = rand(1:10)
    return drawn
end

grand1to10 (generic function with 1 method)

In [9]:
grand1to10()

3

In [10]:
grand1to10()

3

In [11]:
grand1to10()

3

In [12]:
grand1to10(3)

9

In [13]:
grand1to10(5)

9

In [14]:
grand1to10(123)

9

In [15]:
grand1to10(Int, "123")

10

In [16]:
grand1to10(Int, "他の文字列")

10

#### コード8-49. 生成関数の例(3)（`gpredtypes()`）

In [17]:
# Describe, as a string, how the types of the two arguments relate.
# Inside the generator `x` and `y` are the argument *types*, so all branching
# below happens once per type combination. Within the quoted strings a plain
# `$x` is ordinary string interpolation of the runtime value, whereas `$($x)`
# splices the type into the generated expression.
@generated function gpredtypes(x, y)
    # Different types: report both values together with their types.
    x === y || return :("$x(::$($x)) and $y(::$($y)) are diffrent types.")
    # Same type from here on; pick the most specific message.
    x === Int && return :("Both $x and $y are `Int`.")
    x <: Number && return :("Both $x and $y are the same Number types ($($x)).")
    return :("Both $x and $y are the same types ($($x)).")
end

gpredtypes (generic function with 1 method)

In [18]:
gpredtypes(1, 2)

"Both 1 and 2 are `Int`."

In [19]:
gpredtypes("abc", "漢字")

"Both abc and 漢字 are the same types (String)."

In [20]:
gpredtypes(:ok, :NG)

"Both ok and NG are the same types (Symbol)."

In [21]:
gpredtypes(1.0, 2.2)

"Both 1.0 and 2.2 are the same Number types (Float64)."

In [22]:
gpredtypes(1//2, 3//4)

"Both 1//2 and 3//4 are the same Number types (Rational{Int64})."

In [23]:
gpredtypes(1.0, π)

"1.0(::Float64) and π(::Irrational{:π}) are diffrent types."

### 8-6-2. 生成関数の特徴と注意点

#### コード8-50. 生成関数のNG例(1)

In [24]:
# Expression used by `gdouble` for string arguments: `a^2` repeats the string twice.
# Dispatches on the argument *type*, which is what a generator receives.
function _double_impl(::Type{<:AbstractString})
    return :(a^2)
end

_double_impl (generic function with 1 method)

In [25]:
# "Double" `a` using the expression that `_double_impl` returns for its type.
# Inside the generator `a` is the argument type, hence the `::Type{...}` methods.
# NOTE: this is the deliberate NG example — the recorded error shows that a
# `_double_impl` method added after this definition is "too new" to be called
# from the generator (world age).
@generated function gdouble(a)
    return _double_impl(a)
end

gdouble (generic function with 1 method)

In [26]:
# Expression for numeric arguments: `2a`.
# This method is added *after* `gdouble` was defined; the recorded
# `gdouble(2)` call fails because the generator cannot see it (world age).
function _double_impl(::Type{<:Number})
    return :(2a)
end

_double_impl (generic function with 2 methods)

In [27]:
gdouble("a")

"aa"

In [28]:
gdouble(2)

LoadError: MethodError: no method matching _double_impl(::Type{Int64})
The applicable method may be too new: running in world age 32459, while current world is 32460.
Closest candidates are:
  _double_impl(::Type{<:Number}) at In[26]:1 (method too new to be called from this world context.)
  _double_impl(::Type{<:AbstractString}) at In[24]:1

#### コード8-51. 生成関数のNG例(2)

In [29]:
# Return `true` when a one-argument `iterate` method exists for type `T`.
# The answer reflects the methods known at the moment of the call.
function isiterable(::Type{T}) where {T}
    return hasmethod(iterate, Tuple{T})
end

isiterable (generic function with 1 method)

In [30]:
# Print every element of `itr` on its own line, provided the type of `itr`
# is iterable according to `isiterable` when the body is generated.
# For a non-iterable type the generator returns `nothing`, so the call does
# nothing at all.
# NOTE: deliberate NG example — the recorded cells show that defining
# `iterate` for a type later does not change the behaviour of a call whose
# body was already generated.
@generated function checkanditerate(itr)
    isiterable(itr) || return nothing
    return quote
        for item in itr
            println(item)
        end
    end
end

checkanditerate (generic function with 1 method)

In [31]:
checkanditerate([1, 2, 3])

1
2
3


In [32]:
# Minimal wrapper around a single array, stored in the field `arr`.
# The type parameter `T` records the concrete array type of that field.
struct WrapArray{T<:AbstractArray}
    arr::T
end

In [33]:
a = WrapArray([1, 2, 3])

WrapArray{Vector{Int64}}([1, 2, 3])

In [34]:
checkanditerate(a)  # nothing is executed (strictly speaking, only `nothing` is evaluated)

In [35]:
# Implement `iterate()` after the fact by forwarding to the wrapped array.
# `st...` carries the optional iteration state, so this single method serves
# both `iterate(a)` and `iterate(a, state)`.
function Base.iterate(a::WrapArray, st...)
    return iterate(a.arr, st...)
end

In [36]:
for v in a
    println(v)
end

1
2
3


In [37]:
checkanditerate(a)  # still nothing is executed, even though `a` can now be iterated

### 8-6-3. 実例

#### 仮想コード8-5. `bdot()` 関数（仕様）

```julia
julia> x = Float32[1, 2, 3];

julia> v = [2, 3, 1];

julia> bdot(v, x)  # == `dot(v, x)` と同等（`dot()` は要 `using LinearAlgebra`）
11.0f0

julia> A = [1 4 7; 2 5 8; 3 6 9]
3×3 Matrix{Int64}:
 1  4  7
 2  5  8
 3  6  9

julia> bdot(A, x)  # [col' * x for col in eachcol(A)]  相当
3-element Vector{Float32}:
 14.0
 32.0
 50.0

julia> A3 = reshape(1:24, (3, 4, 2))
3×4×2 reshape(::UnitRange{Int64}, 3, 4, 2) with eltype Int64:
[:, :, 1] =
 1  4  7  10
 2  5  8  11
 3  6  9  12

[:, :, 2] =
 13  16  19  22
 14  17  20  23
 15  18  21  24

julia> bdot(A3, x)
4×2 Matrix{Float32}:
 14.0   86.0
 32.0  104.0
 50.0  122.0
 68.0  140.0

julia> # 以下、4次元以上の多次元配列でも同様
```

#### コード8-52. `bdot()` 関数相当のコード（挙動の確認）

In [38]:
x = Float32[1, 2, 3]

3-element Vector{Float32}:
 1.0
 2.0
 3.0

In [39]:
v = 1:3

1:3

In [40]:
v' * x  # == `dot(v, x)`（ただし要 `using LinearAlgebra`）

14.0f0

In [41]:
A = reshape(1:6, (3, 2))

3×2 reshape(::UnitRange{Int64}, 3, 2) with eltype Int64:
 1  4
 2  5
 3  6

In [42]:
A' * x

2-element Vector{Float32}:
 14.0
 32.0

In [43]:
y = zeros(Float32, 2);

In [44]:
# Accumulate y[j] = Σ_i A[i, j] * x[i] for the 3×2 `A` into the zero-filled `y`.
# Equivalent to two nested loops with `i` as the inner (fastest) index.
for j in 1:2, i in 1:3
    y[j] += A[i, j] * x[i]
end

In [45]:
y  # == A' * x

2-element Vector{Float32}:
 14.0
 32.0

In [46]:
A3 = reshape(1:12, (3, 2, 2));
B = zeros(Float32, (2, 2));

In [47]:
# Same accumulation for the 3-d `A3`:
#   B[i_2, i_3] = Σ_{i_1} A3[i_1, i_2, i_3] * x[i_1]
# Equivalent to three nested loops with `i_1` innermost.
for i_3 in axes(A3, 3), i_2 in axes(A3, 2), i_1 in axes(A3, 1)
    B[i_2, i_3] += A3[i_1, i_2, i_3] * x[i_1]
end

In [48]:
B

2×2 Matrix{Float32}:
 14.0  50.0
 32.0  68.0

#### コード8-53. `bdot()` 関数の実装

In [49]:
# Vector / matrix case: contract the first dimension of `A` with `x` using the
# adjoint product (`A'` is `adjoint(A)`).
# NOTE(review): the adjoint conjugates complex entries, whereas the N-d
# `@generated` method multiplies elements as-is — confirm which behaviour is
# intended for complex `A`.
function bdot(A::AbstractVecOrMat, x::AbstractVector)
    return adjoint(A) * x
end

bdot (generic function with 1 method)

In [50]:
# N-dimensional "batched dot": contract the first dimension of `A` with `x`,
#   y[i_2, …, i_N] = Σ_{i_1} A[i_1, i_2, …, i_N] * x[i_1]
# The generator writes out an N-deep loop nest for the given dimensionality.
@generated function bdot(A::AbstractArray{T1, N}, x::AbstractVector{T2}) where {T1, T2, N}
    # Element type of the result.
    T = promote_type(T1, T2)
    # One index symbol per dimension: :i_1, :i_2, …
    idxs = ntuple(d -> Symbol("i_", d), Val(N))
    # Innermost statement of the loop nest.
    kernel = :(y[$(idxs[2:end]...)] += A[$(idxs...)] * x[$(idxs[1])])
    # Wrap the kernel in one `for` per dimension; dimension 1 ends up innermost,
    # dimension N outermost.
    loops = foldl(1:N; init=kernel) do body, d
        :(for $(idxs[d]) in axes(A, $d)
            $body
        end)
    end
    return quote
        y = zeros($T, size(A)[2:$N])
        $loops
        y
    end
end

bdot (generic function with 2 methods)

#### コード8-54. `bdot()` 関数の動作確認

In [51]:
x = Float32[1, 2, 3];

In [52]:
v = 1:3;
bdot(v, x)  # == v' * x

14.0f0

In [53]:
A = reshape(1:6, (3, 2));
bdot(A, x)  # == A' * x

2-element Vector{Float32}:
 14.0
 32.0

In [54]:
A3 = reshape(1:12, (3, 2, 2));
bdot(A3, x)

2×2 Matrix{Float32}:
 14.0  50.0
 32.0  68.0

In [55]:
A4 = reshape(1:24, (3, 2, 2, 2));
bdot(A4, x)

2×2×2 Array{Float32, 3}:
[:, :, 1] =
 14.0  50.0
 32.0  68.0

[:, :, 2] =
  86.0  122.0
 104.0  140.0

#### コード8-55. `bdot()` の生成する引用表現の確認

In [56]:
# Build, as an ordinary function, the quoted expression for the N-dimensional
# batched dot so that it can be inspected. Only the argument *types* are used.
# The returned code refers to variables named `A`, `x` and `y`.
function bdot_impl(::AbstractArray{T1, N}, ::AbstractVector{T2}) where {T1, T2, N}
    # Element type of the result.
    T = promote_type(T1, T2)
    # One index symbol per dimension: :i_1, :i_2, …
    idxs = ntuple(d -> Symbol("i_", d), Val(N))
    # Innermost statement of the loop nest.
    kernel = :(y[$(idxs[2:end]...)] += A[$(idxs...)] * x[$(idxs[1])])
    # Wrap the kernel in one `for` per dimension; dimension 1 ends up innermost,
    # dimension N outermost.
    nest = foldl(1:N; init=kernel) do inner, d
        :(for $(idxs[d]) in axes(A, $d)
            $inner
        end)
    end
    return quote
        y = zeros($T, size(A)[2:$N])
        $nest
        y
    end
end

bdot_impl (generic function with 1 method)

In [57]:
bdot_impl(A, x) |> Base.remove_linenums!  # `A isa AbstractArray{Int, 2}`

quote
    y = zeros(Float32, (size(A))[2:2])
    for i_2 = axes(A, 2)
        for i_1 = axes(A, 1)
            y[i_2] += A[i_1, i_2] * x[i_1]
        end
    end
    y
end

In [58]:
bdot_impl(A3, x) |> Base.remove_linenums!  # `A3 isa AbstractArray{Int, 3}`

quote
    y = zeros(Float32, (size(A))[2:3])
    for i_3 = axes(A, 3)
        for i_2 = axes(A, 2)
            for i_1 = axes(A, 1)
                y[i_2, i_3] += A[i_1, i_2, i_3] * x[i_1]
            end
        end
    end
    y
end

In [59]:
bdot_impl(A4, x) |> Base.remove_linenums!  # `A4 isa AbstractArray{Int, 4}`

quote
    y = zeros(Float32, (size(A))[2:4])
    for i_4 = axes(A, 4)
        for i_3 = axes(A, 3)
            for i_2 = axes(A, 2)
                for i_1 = axes(A, 1)
                    y[i_2, i_3, i_4] += A[i_1, i_2, i_3, i_4] * x[i_1]
                end
            end
        end
    end
    y
end

In [60]:
# 以下略

#### コード8-56. `bdot()` の他の実装方法とパフォーマンス比較

In [61]:
# Vector / matrix case of `bdot2`: adjoint product of `A` with `x`.
function bdot2(A::AbstractVecOrMat, x::AbstractVector)
    return adjoint(A) * x
end

bdot2 (generic function with 1 method)

In [62]:
# Comprehension-based variant of the batched dot for arbitrary dimensions.
# For every Cartesian index `j` over dimensions 2…N, sum `A[i, j] * x[i]`
# along the first dimension; the result has the shape of those trailing axes.
function bdot2(A::AbstractArray, x::AbstractVector)
    trailing = CartesianIndices(axes(A)[2:end])
    return [sum(A[i, j] * x[i] for i in axes(A, 1)) for j in trailing]
end

bdot2 (generic function with 2 methods)

In [63]:
# Vector / matrix case of `bdot3`: adjoint product of `A` with `x`.
function bdot3(A::AbstractVecOrMat, x::AbstractVector)
    return adjoint(A) * x
end

bdot3 (generic function with 1 method)

In [64]:
# Broadcast-based variant of the batched dot for arbitrary dimensions.
# Scale `A` by `x` along the first dimension, sum over that dimension, then
# drop it by reshaping to the trailing sizes.
function bdot3(A::AbstractArray, x::AbstractVector)
    weighted = A .* x
    summed = sum(weighted; dims=1)
    return reshape(summed, size(A)[2:end])
end

bdot3 (generic function with 2 methods)

In [65]:
bdot(v, x) == bdot2(v, x) == bdot3(v, x)

true

In [66]:
bdot(A, x) == bdot2(A, x) == bdot3(A, x)

true

In [67]:
bdot(A3, x) == bdot2(A3, x) == bdot3(A3, x)

true

In [68]:
bdot(A4, x) == bdot2(A4, x) == bdot3(A4, x)

true

In [69]:
A5 = reshape(1:48, (3, 2, 2, 2, 2));

In [70]:
bdot(A5, x) == bdot2(A5, x) == bdot3(A5, x)

true

In [71]:
@time bdot(A5, x);

  0.000005 seconds (2 allocations: 192 bytes)


In [72]:
@time bdot2(A5, x);

  0.000007 seconds (2 allocations: 192 bytes)


In [73]:
@time bdot3(A5, x);

  0.000007 seconds (6 allocations: 640 bytes)


In [74]:
using BenchmarkTools

In [75]:
@btime bdot($A5, $x);

  210.359 ns (2 allocations: 192 bytes)


In [76]:
@btime bdot2($A5, $x);

  824.557 ns (2 allocations: 192 bytes)


In [77]:
@btime bdot3($A5, $x);

  296.067 ns (6 allocations: 640 bytes)
