7 changes: 5 additions & 2 deletions docs/make.jl
@@ -5,6 +5,8 @@ using Documenter, Literate
using Pkg

# Literate Examples
const DRAFT = get(ENV, "DRAFT", "false") == "true"
@show DRAFT
const EXAMPLE_DIR = pkgdir(TensorInference, "examples")
const LITERATE_GENERATED_DIR = pkgdir(TensorInference, "docs", "src", "generated")
mkpath(LITERATE_GENERATED_DIR)
@@ -19,7 +21,7 @@ for each in readdir(EXAMPLE_DIR)
# build
input_file = joinpath(workdir, "main.jl")
@info "building" input_file
Literate.markdown(input_file, workdir; execute=true)
Literate.markdown(input_file, workdir; execute=!DRAFT)
# restore environment
# Pkg.activate(Pkg.PREV_ENV_PATH[])
end
@@ -30,7 +32,7 @@ for each in EXTRA_JL
cp(joinpath(SRC_DIR, each), joinpath(LITERATE_GENERATED_DIR, each); force=true)
input_file = joinpath(LITERATE_GENERATED_DIR, each)
@info "building" input_file
Literate.markdown(input_file, LITERATE_GENERATED_DIR; execute=true)
Literate.markdown(input_file, LITERATE_GENERATED_DIR; execute=!DRAFT)
end

DocMeta.setdocmeta!(TensorInference, :DocTestSetup, :(using TensorInference); recursive=true)
@@ -68,6 +70,7 @@ makedocs(;
],
doctest = false,
warnonly = :missing_docs,
draft = DRAFT,
)

deploydocs(;
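The change above threads a `DRAFT` environment variable through both the Literate example build (`execute = !DRAFT`) and `makedocs` (`draft = DRAFT`). As a hedged sketch, a local draft build could be triggered like this, assuming the `docs` environment is already instantiated and the command is run from the package root (the invocation itself is not part of this diff):

```julia
# Skip Literate execution and enable Documenter's draft mode for a fast local build.
# Assumption: run from the package root with the docs/ project active.
ENV["DRAFT"] = "true"
include(joinpath("docs", "make.jl"))
```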
6 changes: 6 additions & 0 deletions src/RescaledArray.jl
@@ -16,6 +16,12 @@ Base.show(io::IO, ::MIME"text/plain", c::RescaledArray) = Base.show(io, c)
Base.Array(c::RescaledArray) = rmul!(Array(c.normalized_value), exp(c.log_factor))
Base.copy(c::RescaledArray) = RescaledArray(c.log_factor, copy(c.normalized_value))
Base.getindex(r::RescaledArray, indices...) = map(x->x * exp(r.log_factor), getindex(r.normalized_value, indices...))
Base.similar(r::RescaledArray, ::Type{T}, dims::Dims) where {T} = RescaledArray(r.log_factor, similar(r.normalized_value, T, dims))
Base.selectdim(r::RescaledArray, d::Int, i::Int) = RescaledArray(r.log_factor, selectdim(r.normalized_value, d, i))
function Base.copyto!(dest::RescaledArray, src::RescaledArray)
dest.normalized_value .= exp(src.log_factor - dest.log_factor) .* src.normalized_value
return dest
end

"""
$(TYPEDSIGNATURES)
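For context on the new `copyto!` method: a `RescaledArray` represents `exp(log_factor) .* normalized_value`, so copying between arrays that carry different log factors must fold the factor ratio `exp(src.log_factor - dest.log_factor)` into the entries. A minimal sketch of that invariant, with made-up values not taken from this diff:

```julia
using TensorInference  # exports RescaledArray

src  = RescaledArray(2.0, [0.1 0.2; 0.3 0.4])  # represents exp(2.0) .* [0.1 0.2; 0.3 0.4]
dest = RescaledArray(1.0, zeros(2, 2))          # carries a different log factor

copyto!(dest, src)                 # scales the copied entries by exp(2.0 - 1.0)
@assert Array(dest) ≈ Array(src)   # the represented values agree
```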
38 changes: 27 additions & 11 deletions src/sampling.jl
@@ -32,6 +32,10 @@ function eliminate_dimensions(x::AbstractArray{T, N}, ix::AbstractVector{L}, el:
@assert length(ix) == N
return x[eliminated_selector(size(x), ix, el.first, el.second)...]
end
function eliminate_dimensions(x::RescaledArray{T, N}, ix::AbstractVector{L}, el::Pair{<:AbstractVector{L}, <:AbstractVector}) where {T, N, L}
return RescaledArray(x.log_factor, eliminate_dimensions(x.normalized_value, ix, el))
end

function eliminated_size(size0, ix, labels)
@assert length(size0) == length(ix)
return ntuple(length(ix)) do i
@@ -53,7 +57,7 @@ function eliminate_dimensions_addbatch!(x::AbstractArray{T, N}, ix::AbstractVect
@assert length(ix) == N
res = similar(x, (eliminated_size(size(x), ix, el.first)..., nbatch))
for ibatch in 1:nbatch
selectdim(res, N+1, ibatch) .= eliminate_dimensions(x, ix, el.first=>view(el.second, :, ibatch))
copyto!(selectdim(res, N+1, ibatch), eliminate_dimensions(x, ix, el.first=>view(el.second, :, ibatch)))
end
push!(ix, batch_label)
return res
@@ -63,7 +67,7 @@ function eliminate_dimensions_withbatch(x::AbstractArray{T, N}, ix::AbstractVect
@assert length(ix) == N && size(x, N) == nbatch
res = similar(x, (eliminated_size(size(x), ix, el.first)))
for ibatch in 1:nbatch
selectdim(res, N, ibatch) .= eliminate_dimensions(selectdim(x, N, ibatch), ix[1:end-1], el.first=>view(el.second, :, ibatch))
copyto!(selectdim(res, N, ibatch), eliminate_dimensions(selectdim(x, N, ibatch), ix[1:end-1], el.first=>view(el.second, :, ibatch)))
end
return res
end
@@ -79,28 +83,28 @@ Returns a vector of vectors, each element being a configuration defined on `get_
* `n` is the number of samples to be returned.

### Keyword Arguments
* `rescale` is a boolean flag to indicate whether to rescale the tensors during contraction.
* `usecuda` is a boolean flag to indicate whether to use CUDA for tensor computation.
* `queryvars` is the variables to be sampled, default is `get_vars(tn)`.
"""
function sample(tn::TensorNetworkModel, n::Int; usecuda = false, queryvars = get_vars(tn))::Samples
function sample(tn::TensorNetworkModel, n::Int; usecuda = false, queryvars = get_vars(tn), rescale::Bool = false)::Samples
# generate tropical tensors with its elements being log(p).
xs = adapt_tensors(tn; usecuda, rescale = false)
xs = adapt_tensors(tn; usecuda, rescale)
# infer size from the contraction code and the input tensors `xs`, returns a label-size dictionary.
size_dict = OMEinsum.get_size_dict!(getixsv(tn.code), xs, Dict{Int, Int}())
# forward compute and cache intermediate results.
cache = cached_einsum(tn.code, xs, size_dict)
# initialize `y̅` as the initial batch of samples.
iy = getiyv(tn.code)
idx = map(l->findfirst(==(l), queryvars), iy ∩ queryvars)
indices = StatsBase.sample(CartesianIndices(size(cache.content)), _Weights(vec(cache.content)), n)
indices = StatsBase.sample(CartesianIndices(size(cache.content)), _weight(cache.content), n)
configs = zeros(Int, length(queryvars), n)
for i=1:n
configs[idx, i] .= indices[i].I .- 1
end
samples = Samples(configs, queryvars)
# back-propagate
env = similar(cache.content, (size(cache.content)..., n)) # batched env
fill!(env, one(eltype(env)))
env = ones_like(cache.content, n)
batch_label = _newindex(OMEinsum.uniquelabels(tn.code))
code = deepcopy(tn.code)
iy_env = [OMEinsum.getiyv(code)..., batch_label]
@@ -115,10 +119,22 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false, queryvars = get
end
_newindex(labels::AbstractVector{<:Union{Int, Char}}) = maximum(labels) + 1
_newindex(::AbstractVector{Symbol}) = gensym(:batch)
_Weights(x::AbstractVector{<:Real}) = Weights(x)
function _Weights(x::AbstractArray{<:Complex})
_weight(x::AbstractArray{<:Real}) = Weights(_normvec(x))
function _weight(_x::AbstractArray{<:Complex})
x = _normvec(_x)
@assert all(e->abs(imag(e)) < max(100*eps(abs(e)), 1e-8), x) "Complex probability encountered: $x"
return Weights(real.(x))
return _weight(real.(x))
end
_normvec(x::AbstractArray) = vec(x)
_normvec(x::RescaledArray) = vec(x.normalized_value)

function ones_like(x::AbstractArray{T}, n::Int) where {T}
res = similar(x, (size(x)..., n))
fill!(res, one(eltype(res)))
return res
end
function ones_like(x::RescaledArray, n::Int)
return RescaledArray(zero(x.log_factor), ones_like(x.normalized_value, n))
end

function generate_samples!(se::SlicedEinsum, cache::CacheTree{T}, iy_env::Vector{Int}, env::AbstractArray{T}, samples::Samples{L}, pool, batch_label::L, size_dict::Dict{L}) where {T, L}
@@ -177,7 +193,7 @@ function update_samples!(labels, sample, vars::AbstractVector{L}, probabilities:
@assert length(vars) == N
totalset = CartesianIndices(probabilities)
eliminated_locs = idx4labels(labels, vars)
config = StatsBase.sample(totalset, _Weights(vec(probabilities)))
config = StatsBase.sample(totalset, _weight(probabilities))
sample[eliminated_locs] .= config.I .- 1
end

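Taken together, these changes let `sample` run the contraction on `RescaledArray` tensors, which keeps intermediate values in a numerically safe range when the raw factors would otherwise overflow or underflow. A hedged usage sketch mirroring the new test added further down (the model size, bond dimension, and seed here are illustrative):

```julia
using TensorInference, Random

Random.seed!(1)
# A matrix product state whose tensors may produce very large or very small products.
mps = random_matrix_product_state(Float64, 50, 8)
# rescale = true routes the contraction and sampling through RescaledArray tensors.
configs = sample(mps, 100; rescale = true)
```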
152 changes: 152 additions & 0 deletions test/RescaledArray.jl
@@ -0,0 +1,152 @@
using Test
using TensorInference
using OMEinsum

@testset "RescaledArray" begin
# Test basic construction
@testset "Construction" begin
α = 2.0
T = [1.0 2.0; 3.0 4.0]
r = RescaledArray(α, T)

@test r.log_factor == α
@test r.normalized_value == T
@test size(r) == (2, 2)
@test size(r, 1) == 2
@test size(r, 2) == 2
end

# Test rescale_array function
@testset "rescale_array" begin
T = [1.0 2.0; 3.0 4.0]
r = TensorInference.rescale_array(T)

# Maximum absolute value should be 1 in normalized_value
@test maximum(abs, r.normalized_value) ≈ 1.0

# Original array should be recoverable
@test Array(r) ≈ T

# Test with zero array
zero_T = zeros(2, 2)
r_zero = TensorInference.rescale_array(zero_T)
@test r_zero.log_factor == 0.0
@test r_zero.normalized_value == zero_T
end

# Test Array conversion
@testset "Array conversion" begin
α = 1.5
T = [0.5 1.0; 0.25 0.75]
r = RescaledArray(α, T)

expected = exp(α) * T
@test Array(r) ≈ expected
end

# Test indexing
@testset "Indexing" begin
α = 0.5
T = [1.0 2.0; 3.0 4.0]
r = RescaledArray(α, T)

@test r[1, 1] ≈ T[1, 1] * exp(α)
@test r[2, 2] ≈ T[2, 2] * exp(α)
@test r[1:2, 1] ≈ T[1:2, 1] * exp(α)
end

# Test copy
@testset "Copy" begin
α = 1.0
T = [1.0 2.0; 3.0 4.0]
r = RescaledArray(α, T)
r_copy = copy(r)

@test r_copy.log_factor == r.log_factor
@test r_copy.normalized_value == r.normalized_value
@test r_copy.normalized_value !== r.normalized_value # Different objects
end

# Test selectdim
@testset "selectdim" begin
T = reshape(Float64.(1:8), 2, 2, 2) # Convert to Float64 to match log factor type
α = 0.5
r = RescaledArray(α, T)

r_slice = selectdim(r, 3, 1)
@test r_slice.log_factor == α
@test r_slice.normalized_value == selectdim(T, 3, 1)
end

# Test einsum operations
@testset "Einsum operations" begin
# Create two rescaled arrays
α1, α2 = 1.0, 1.5
T1 = [1.0 0.5; 0.25 1.0]
T2 = [0.5 1.0; 1.0 0.5]

r1 = RescaledArray(α1, T1)
r2 = RescaledArray(α2, T2)

# Test matrix multiplication via einsum
code = ein"ij,jk->ik"
result = einsum(code, (r1, r2))

# Compare with regular array multiplication
expected_array = Array(r1) * Array(r2)
@test Array(result) ≈ expected_array

# The log factor should be the sum of input log factors plus rescaling
@test result isa RescaledArray
end

# Test fill! and conj
@testset "fill! and conj" begin
α = 0.5
T = [1.0 2.0; 3.0 4.0]
r = RescaledArray(α, copy(T))

# Test fill!
fill!(r, 2.0)
expected_fill_value = 2.0 / exp(α)
@test all(x -> x ≈ expected_fill_value, r.normalized_value)

# Test conj with complex numbers
α_complex = 1.0 + 0.5im
T_complex = [1.0+1.0im 2.0+2.0im; 3.0+3.0im 4.0+4.0im]
r_complex = RescaledArray(α_complex, T_complex)
r_conj = conj(r_complex)

@test r_conj.log_factor == conj(α_complex)
@test r_conj.normalized_value == conj(T_complex)
end

# Test show methods
@testset "Display" begin
α = 1.0
T = [1.0 2.0]
r = RescaledArray(α, T)

# Test that show methods don't error
@test sprint(show, r) isa String
@test sprint(show, "text/plain", r) isa String
end

# Test copyto!
@testset "copyto!" begin
α = 2.0
T = [1.0 2.0; 3.0 4.0]
r = RescaledArray(α, T)
r_copy = similar(r)
copyto!(r_copy, r)
@test Array(r_copy) ≈ Array(r)

α = 2.0
T = [1.0 2.0; 3.0 4.0]
r = RescaledArray(α, T)
r_copy = similar(r)
copyto!(selectdim(r_copy, 1, 1), selectdim(r, 1, 1))
@test Array(r_copy)[1, :] ≈ Array(r)[1, :]
@test !(Array(r_copy)[2, :] ≈ Array(r)[2, :])
end
end
4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -36,6 +36,10 @@ end
include("fileio.jl")
end

@testset "RescaledArray" begin
include("RescaledArray.jl")
end

using CUDA
if CUDA.functional()
include("cuda.jl")
10 changes: 10 additions & 0 deletions test/sampling.jl
@@ -86,4 +86,14 @@ end
entropy(probs) = -sum(probs .* log.(probs))
@show negative_loglikelyhood(probs, indices), entropy(probs)
@test negative_loglikelyhood(probs, indices) ≈ entropy(probs) atol=1e-1
end

@testset "issue 102 - support using rescaled array in sampling" begin
n = 100
chi = 10
Random.seed!(140)
mps = random_matrix_product_state(Float64, n, chi)
mps.tensors[setdiff(1:length(mps.tensors), mps.unity_tensors_idx)] .*= 100
samples = sample(mps, 1; rescale = true)
@test samples isa TensorInference.Samples
end