Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TypeError in JETAnalyzer #271

Closed
goerch opened this issue Nov 29, 2021 · 0 comments
Closed

TypeError in JETAnalyzer #271

goerch opened this issue Nov 29, 2021 · 0 comments

Comments

@goerch
Copy link
Contributor

goerch commented Nov 29, 2021

When checking the following MWE (minimal working example):

using Flux
using Dates

# Base feature-channel width used by the generator (`resblock`, `upsample`).
channels = 64
# NOTE(review): this global is not referenced anywhere in the script below;
# `block` and `upsample` each take a parameter of the same name that shadows it.
in_channels = 3
# Channel progression for the discriminator: stage i maps
# num_channels[i] => num_channels[i+1] (8 stages for these 9 entries).
num_channels = [3,64,64,128,128,256,256,512,512]

# discriminator

# Build one discriminator stage as a plain vector of layers, meant to be
# splatted into a `Chain`: a 3×3 `Conv` from `in_channels` to `out_channels`
# (pad = 1, the given `stride`), optionally followed by `BatchNorm(out_channels)`.
#
# The first positional argument is destructured, so callers may pass a `Pair`
# (`a => b`, as the discriminator below does) or a 2-tuple.
# Keywords:
#   stride         - stride of the 3×3 convolution (default 1)
#   use_batch_norm - append a BatchNorm layer after the conv (default true)
# Returns the layers as a `Vector{Any}`.
function block((in_channels,out_channels); stride = 1, use_batch_norm = true)
    # Untyped `[]` yields a `Vector{Any}`. Left as-is on purpose: this script is
    # a JET reproducer, and changing container types can change what JET infers.
    layers = []
    push!(layers, Conv((3,3),in_channels => out_channels,pad = 1,stride = stride))
    if use_batch_norm
        push!(layers, BatchNorm(out_channels))
    end
    # LeakyReLU activation is commented out in this MWE.
    #push!(layers, x -> leakyrelu.(x,0.2f0))
    return layers
end

# Discriminator: eight `block` stages over consecutive `num_channels` pairs
# (3=>64, 64=>64, 64=>128, ..., 512=>512). They are concatenated into one flat
# layer list with `reduce(vcat, ...)` and splatted into the `Chain`.
# Stage i uses stride 1 for odd i and stride 2 for even i (`1 + (i+1) % 2`).
# Only the first stage (i == 1) omits BatchNorm.
# The head is adaptive mean pooling to a (1,1) output, followed by two 1×1
# convolutions, 512 => 1024 => 1 (the LeakyReLU between them is commented out).
# The whole model is piped through Flux's `gpu`.
discriminator = Chain(
    reduce(vcat,[block(num_channels[i] => num_channels[i+1];
            stride = 1 + (i+1) % 2,
            use_batch_norm = i!=1) for i = 1:length(num_channels)-1 ])...,
    AdaptiveMeanPool((1,1)),
    Conv((1,1), num_channels[end] => 1024),
    #x -> leakyrelu.(x,0.2f0),
    Conv((1,1), 1024 => 1),
) |> gpu


# Residual block that preserves the channel count:
# Conv(3×3, pad = 1) -> BatchNorm -> Conv(3×3, pad = 1) -> BatchNorm.
# It is wrapped in `SkipConnection(..., +)`, so the block's input is added to
# the inner chain's output. PReLU activations are commented out in this MWE.
function resblock(channels)
    return SkipConnection(Chain(
            Conv((3,3),channels => channels, pad=1),
            BatchNorm(channels),
            #Prelu(),
            Conv((3,3),channels => channels, pad=1),
            BatchNorm(channels),
        )
        , +)
end


# Upsampling stage, returned as a plain vector of layers that the caller
# splats into a `Chain`.
# A 3×3 Conv (pad = 1) first expands `in_channels` to `in_channels * up_scale^2`.
# `PixelShuffle(up_scale)` then trades those extra channels for an
# `up_scale`× larger spatial size, so the stage outputs `in_channels` channels.
# PReLU is commented out in this MWE.
function upsample(in_channels, up_scale)
    return [
        Conv((3,3),in_channels => in_channels*up_scale^2,pad=1),
        PixelShuffle(up_scale),
        #Prelu(),
    ]
end

# Generator, in order:
#   1. a 9×9 conv (pad = 4) lifting 3 => `channels`;
#   2. a long `SkipConnection(..., +)` around five `resblock`s plus a 3×3 conv
#      and BatchNorm;
#   3. two `upsample(channels, 2)` stages (4× spatial upscaling overall);
#   4. a final 9×9 conv back to 3 channels with activation `σ`.
# The model is piped through Flux's `gpu`; the trailing `;` only suppresses
# REPL display.
generator = Chain(
    Conv((9,9),3 => channels, pad = 4),
    #Prelu(),

    SkipConnection(Chain(
        # try different numbers of residual blocks here
        resblock(channels),
        resblock(channels),
        resblock(channels),
        resblock(channels),
        resblock(channels),
        Conv((3,3),channels => channels, pad=1),
        BatchNorm(channels)),+),
    upsample(channels, 2)...,
    upsample(channels, 2)...,
    Conv((9,9),channels => 3,σ, pad=4),
) |> gpu;

# Random stand-in batches of 32 three-channel images, laid out as
# (spatial, spatial, channel, batch).
# They are 88×88 "high-res" and 22×22 "low-res": a 4× ratio matching the
# generator's two 2× upsampling stages.
# `hr_images` is only moved to the GPU; it is not otherwise used in this script.
hr_images = randn(Float32,88,88,3,32)
lr_images = randn(Float32,22,22,3,32)

hr_images = gpu(hr_images)
lr_images = gpu(lr_images)

# check the forward pass: generator output fed through the discriminator
@show sum(discriminator(generator(lr_images)))

# Implicit parameter collection for the generator, differentiated below.
params_g = Flux.params(generator)

@info "generator $(Dates.now())"

# Taking the gradient of the generator: pullback of the scalar discriminator
# score with respect to `params_g`, timed with `@time`.
# The returned pullback closure `back` is not invoked here.
loss_g, back = @time Flux.pullback(params_g) do
    sum(discriminator(generator(lr_images)))
end

Checking it with JET's `report_file`, I got the following error:

ERROR: TypeError: in typeassert, expected Vector{Any}, got a value of type Vector{Core.MethodInstance}
Stacktrace:
  [1] finish(me::Core.Compiler.InferenceState, analyzer::JET.JETAnalyzer{JET.BasicPass{typeof(JET.basic_function_filter)}})
    @ JET C:\Users\Win10\.julia\packages\JET\iloOP\src\abstractinterpret\typeinfer.jl:703
  [2] _typeinf(analyzer::JET.JETAnalyzer{JET.BasicPass{typeof(JET.basic_function_filter)}}, frame::Core.Compiler.InferenceState)
    @ JET C:\Users\Win10\.julia\packages\JET\iloOP\src\abstractinterpret\typeinfer.jl:613
  [3] typeinf(interp::JET.JETAnalyzer{JET.BasicPass{typeof(JET.basic_function_filter)}}, frame::Core.Compiler.InferenceState)
    @ Core.Compiler .\compiler\typeinfer.jl:209
  [4] typeinf(analyzer::JET.JETAnalyzer{JET.BasicPass{typeof(JET.basic_function_filter)}}, frame::Core.Compiler.InferenceState)
    @ JET C:\Users\Win10\.julia\packages\JET\iloOP\src\abstractinterpret\typeinfer.jl:528
  [5] typeinf_edge(interp::JET.JETAnalyzer{JET.BasicPass{typeof(JET.basic_function_filter)}}, method::Method, atypes::Any, sparams::Core.SimpleVector, caller::Core.Compiler.InferenceState)
    @ Core.Compiler .\compiler\typeinfer.jl:823
  [6] typeinf_edge
    @ C:\Users\Win10\.julia\packages\JET\iloOP\src\abstractinterpret\typeinfer.jl:347 [inlined]
  [7] abstract_call_method(interp::JET.JETAnalyzer{JET.BasicPass{typeof(JET.basic_function_filter)}}, method::Method, sig::Any, sparams::Core.SimpleVector, hardlimit::Bool, sv::Core.Compiler.InferenceState)
    @ Core.Compiler .\compiler\abstractinterpretation.jl:504
  [8] abstract_call_method
    @ C:\Users\Win10\.julia\packages\JET\iloOP\src\abstractinterpret\typeinfer.jl:170 [inlined]
  [9] abstract_call_gf_by_type(interp::JET.JETAnalyzer{JET.BasicPass{typeof(JET.basic_function_filter)}}, f::Any, fargs::Vector{Any}, argtypes::Vector{Any}, atype::Any, sv::Core.Compiler.InferenceState, max_methods::Int64)
    @ Core.Compiler .\compiler\abstractinterpretation.jl:105
 [10] abstract_call_gf_by_type
    @ C:\Users\Win10\.julia\packages\JET\iloOP\src\analyzers\jetanalyzer.jl:331 [inlined]
 [11] abstract_call_known(interp::JET.JETAnalyzer{JET.BasicPass{typeof(JET.basic_function_filter)}}, f::Any, fargs::Vector{Any}, argtypes::Vector{Any}, sv::Core.Compiler.InferenceState, max_methods::Int64)
    @ Core.Compiler .\compiler\abstractinterpretation.jl:1339

This happens on both Julia 1.7.0-rc3 and 1.8.0-DEV.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant