
Fit at the beginning of an array takes 20x more iterations than the rest #146

@svilupp


As discussed, there is an odd artefact: the fit becomes good almost immediately across the whole series except at the beginning, which takes up to 20x more iterations to converge (the number of iterations needed appears to depend mainly on the array size).
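One way to quantify "takes up to 20x more iterations" per data point is to record the fitted means at every iteration and find, for each index, the first iteration after which the mean stays within a tolerance of its final value. A minimal sketch, assuming the per-iteration means have already been collected into a matrix `M` (rows = iterations, columns = data points), for example (assumed) by requesting `aux = KeepEach()` instead of `KeepLast()` and stacking the means; `first_converged_iteration` is a hypothetical helper, not part of ReactiveMP:

```julia
# Hypothetical helper: for each data point (column of M), return the first iteration
# from which the fitted mean stays within `tol` of its final value for all later iterations.
function first_converged_iteration(M::AbstractMatrix{<:Real}; tol::Real = 1e-3)
    n_iter, n = size(M)
    out = Vector{Int}(undef, n)
    for j in 1:n
        final = M[n_iter, j]
        k = n_iter
        # walk backwards while the trace is still within tolerance of the final value
        while k > 1 && abs(M[k - 1, j] - final) <= tol
            k -= 1
        end
        out[j] = k
    end
    return out
end

# Toy traces: point 1 settles at iteration 2, point 2 only at iteration 4
M = [0.0 0.0;
     1.0 0.2;
     1.0 0.6;
     1.0 1.0;
     1.0 1.0]
first_converged_iteration(M)  # returns [2, 4]
```

Plotting the result against the index would show directly whether the slow points cluster at the start of the series.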

Result after 20 iterations (observe the beginning):
[plot omitted]

Result after 365 iterations (= N, the array length):
[plot omitted]

MWE for a simple curve-fitting scenario

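```julia
using Distributions, LinearAlgebra, Plots, Splines2
# @model, @constraints, inference, etc. come from the ReactiveMP.jl stack
# (e.g. `using ReactiveMP, GraphPPL, Rocket`; exact packages depend on the version used)

# Generate data to fit: a sine wave with a 180-day period, a linear trend and Gaussian noise
time_max=365
noise_scale=0.05

time_index=collect(1:time_max)/time_max   # time rescaled to (0, 1]
p=180/time_max                            # period of the sine term on the rescaled axis
growth=2
offset=0
y=offset .+ sin.(time_index*2π/p) .+ growth.*time_index .+rand(Normal(0,noise_scale),time_max)
plot(y)
```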
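The generating process above is y_i = offset + sin(2π t_i / p) + growth · t_i + ε_i with t_i = i/365 and ε_i ~ N(0, 0.05²). Because p = 180/365, the sine term repeats every 180 samples, so the series covers roughly two cycles. A minimal stdlib-only sketch of the same process, using `randn` scaled by the noise standard deviation in place of `rand(Normal(0, noise_scale), n)`; the function name `simulate_series` and the seeded RNG are illustrative assumptions:

```julia
using Random

# Hypothetical helper reproducing the data-generating process of the MWE
function simulate_series(rng::AbstractRNG; time_max = 365, period_days = 180,
                         growth = 2.0, offset = 0.0, noise_scale = 0.05)
    t = collect(1:time_max) ./ time_max           # rescaled time in (0, 1]
    p = period_days / time_max                    # period on the rescaled axis
    seasonal = sin.(2π .* t ./ p)                 # mathematically sin.(2π .* (1:time_max) ./ period_days)
    trend = growth .* t
    noise = noise_scale .* randn(rng, time_max)   # same distribution as rand(Normal(0, noise_scale), time_max)
    return offset .+ seasonal .+ trend .+ noise, seasonal
end

y_sim, seasonal = simulate_series(MersenneTwister(1))
# the seasonal component repeats every 180 samples
isapprox(seasonal[1], seasonal[181]; atol = 1e-10)  # true
```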

```julia
# Generate a B-spline basis to approximate the function
# Note: the boundary knots must lie outside the data range; otherwise a row of X can be all zeros (which breaks the backprop)
X=Splines2.bs(time_index,df=10,boundary_knots=(-0.01,1.01));
```
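The note about boundary knots can be checked directly before building the model: look for rows of the design matrix whose entries are all zero, since such a row is what the note says breaks the backprop. A minimal sketch; `zero_rows` is a hypothetical helper that works on any matrix, including the output of `Splines2.bs`:

```julia
# Hypothetical helper: indices of rows in which every entry is (numerically) zero
zero_rows(X::AbstractMatrix; atol::Real = 0.0) =
    [i for i in axes(X, 1) if all(x -> abs(x) <= atol, view(X, i, :))]

# Toy design matrix whose first row is all zeros
Xtoy = [0.0 0.0 0.0;
        0.2 0.5 0.3;
        0.0 0.4 0.6]
zero_rows(Xtoy)  # returns [1]
```

In the MWE this would be used as `@assert isempty(zero_rows(X))` right after building the basis.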

```julia
# Build the model
@model function linreg(X,n,dim_x)

    T=Float64

    y = datavar(T, n)
    aux = randomvar(n)

    # note: `sigma` is used as a precision (inverse variance), not a standard deviation
    sigma  ~ GammaShapeRate(1.0, 1.0)
    intercept ~ NormalMeanVariance(0.0, 2.0)
    beta  ~ MvNormalMeanPrecision(zeros(dim_x), diageye(dim_x))

    for i in 1:n
        # linear predictor: intercept plus regression on the spline basis
        aux[i] ~ intercept + dot(X[i,:], beta)
        y[i] ~ NormalMeanPrecision(aux[i], sigma)
    end

    return beta,aux,y
end
```
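Reading the model generatively may help: `sigma` enters `NormalMeanPrecision` as a precision, so the observation noise has standard deviation `1/sqrt(sigma)`, and `GammaShapeRate(1.0, 1.0)` is an Exponential(1) prior on that precision. A stdlib-only sketch that draws one synthetic dataset from the prior to make these parameterisations concrete; `sample_prior` is hypothetical and mirrors, but does not call, the ReactiveMP model:

```julia
using LinearAlgebra, Random

# Hypothetical prior-predictive sampler mirroring the linreg model above
function sample_prior(rng::AbstractRNG, X::AbstractMatrix)
    n, dim_x = size(X)
    sigma = randexp(rng)                  # GammaShapeRate(1, 1) is Exponential(1); a precision
    intercept = sqrt(2.0) * randn(rng)    # NormalMeanVariance(0, 2): variance 2, std sqrt(2)
    beta = randn(rng, dim_x)              # MvNormalMeanPrecision(0, I): covariance I
    aux = [intercept + dot(X[i, :], beta) for i in 1:n]
    y = aux .+ randn(rng, n) ./ sqrt(sigma)   # NormalMeanPrecision(aux, sigma): std 1/sqrt(sigma)
    return (; sigma, intercept, beta, aux, y)
end

draw = sample_prior(MersenneTwister(42), rand(MersenneTwister(0), 365, 10))
```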

```julia
# Mean-field factorization between the latent means and the noise precision
constraints = @constraints begin
    q(aux, sigma) = q(aux)q(sigma)
end
```
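Under the factorization q(aux, sigma) = q(aux)q(sigma), the textbook variational update for the precision is conjugate: with prior Gamma(a₀, b₀) (shape, rate) and Gaussian marginals q(aux_i) with means m_i and variances v_i, q(sigma) = Gamma(a₀ + n/2, b₀ + ½ Σ_i [(y_i − m_i)² + v_i]). A minimal sketch of that standard coordinate-ascent step; it is not taken from ReactiveMP's internals, and `precision_update` is a hypothetical name:

```julia
# Hypothetical helper: mean-field update of q(sigma) for y_i ~ N(aux_i, 1/sigma),
# given Gaussian marginals q(aux_i) with means m and variances v, and a Gamma(a0, b0) (shape, rate) prior
function precision_update(y, m, v; a0 = 1.0, b0 = 1.0)
    n = length(y)
    # E_q[(y_i - aux_i)^2] = (y_i - m_i)^2 + v_i, summed over i
    expected_sq = sum((y .- m) .^ 2 .+ v)
    shape = a0 + n / 2
    rate = b0 + expected_sq / 2
    return shape, rate   # posterior mean of the precision is shape / rate
end

shape, rate = precision_update([1.0, 1.0], [0.0, 0.0], [1.0, 1.0])  # returns (2.0, 3.0)
```

Because the residual term includes v_i, aux marginals that are still far off or very uncertain inflate the rate and so lower the fitted precision.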

```julia
# Run inference
@time results = inference(
    model = Model(linreg,X,size(X)...),   # n = size(X, 1) observations, dim_x = size(X, 2) basis functions
    data  = (y = y,),
    constraints = constraints,
    initmessages = (intercept = vague(NormalMeanVariance),),
    initmarginals = (sigma = GammaShapeRate(1.0, 1.0),),
    returnvars   = (sigma = KeepLast(), beta = KeepLast(), aux = KeepLast()),
    iterations   = 20,
    warn = true,
    free_energy = true
)
```
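Since `free_energy = true` is passed, the free energy per iteration can also be used to decide how many iterations are enough. A minimal sketch, assuming the values are available as a vector (expected to be `results.free_energy` here, though the field name may differ between versions); `iterations_to_converge` is a hypothetical helper. A flat global free energy may not guarantee that every local marginal, such as those at the start of the series, has settled.

```julia
# Hypothetical helper: first iteration k at which the relative change in free energy
# between iterations k-1 and k drops below `rtol`; returns `nothing` if it never does
function iterations_to_converge(fe::AbstractVector{<:Real}; rtol::Real = 1e-6)
    for k in 2:length(fe)
        if abs(fe[k] - fe[k - 1]) <= rtol * abs(fe[k - 1])
            return k
        end
    end
    return nothing
end

fe = [100.0, 50.0, 40.0, 39.99999, 39.99999]
iterations_to_converge(fe)  # returns 4
```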

```julia
# Plot results
# Note: observe the divergence in the first ~50 data points;
# it disappears as the number of iterations increases
# The ribbon is 1/sqrt(E[sigma]), the noise standard deviation implied by the mean precision
plot(mean.(results.posteriors[:aux]), ribbon = (results.posteriors[:sigma]|>mean|>inv|>sqrt),label="Fitted")
plot!(y,label="Observed data")
```
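To put a number on the divergence in the first 50 data points, compare the error of the fitted means against the observations at the start of the series and elsewhere. A minimal sketch; `segment_rmse` is hypothetical, and in the MWE it would be called as `segment_rmse(mean.(results.posteriors[:aux]), y)`:

```julia
# Hypothetical helper: RMSE between fitted means and data on the first `head` points and on the rest
function segment_rmse(fitted::AbstractVector{<:Real}, observed::AbstractVector{<:Real}; head::Int = 50)
    length(fitted) == length(observed) || throw(DimensionMismatch("lengths differ"))
    rmse(a, b) = sqrt(sum(abs2, a .- b) / length(a))
    h = min(head, length(fitted))
    return (head = rmse(fitted[1:h], observed[1:h]),
            rest = h < length(fitted) ? rmse(fitted[h+1:end], observed[h+1:end]) : NaN)
end

# Toy example: a fit that is off by 1.0 on the first 2 points and exact afterwards
segment_rmse([1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]; head = 2)  # returns (head = 1.0, rest = 0.0)
```

Tracking both numbers across runs with different `iterations` values would show how quickly the head of the series catches up with the rest.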
