In [None]:
#### Include
using Plots, LightGraphs, SparseArrays, SimpleWeightedGraphs
using Statistics, BenchmarkTools, LinearAlgebra, ProgressMeter
using Distributions, Base.Threads, CSV, StatsBase
using Base.GC, JLD2, FileIO, Random
plotly();

In [None]:
#### Get data from file
"""
    read_data(filename)

Read a data file made of an `@`-prefixed header followed by comma-separated
rows, and return `(X, y, classes)`.

The header is assumed to hold `num_variables + 5` lines starting with `@`
(one per input variable plus five others), so the number of input variables
is the header length minus five.

# Returns
- `X`: `num_variables × m` `Float64` matrix, one column per data row.
- `y`: length-`m` `Float64` vector of class ids; ids are `1, 2, …` in order of
  first appearance of each label.
- `classes`: `Dict` mapping each label string (last field of a row, taken
  verbatim) to its id.

Blank lines are ignored and every non-blank data row is read, including the
last one.

# Throws
- `ArgumentError` if the file contains no data rows or the header is too short.
"""
function read_data(filename)
    # `readlines(path)` opens and closes the file itself, so no handle leaks.
    lines = readlines(filename)

    isdata(ln) = !isempty(strip(ln)) && !startswith(ln, "@")

    first_data = findfirst(isdata, lines)
    first_data === nothing &&
        throw(ArgumentError("no data rows found in $filename"))

    num_header = count(ln -> startswith(ln, "@"), view(lines, 1:first_data-1))
    num_variables = num_header - 5
    num_variables >= 0 ||
        throw(ArgumentError("header of $filename is too short ($num_header '@' lines)"))

    @show(num_variables)

    # Keep every non-blank row; a trailing empty line is skipped instead of
    # unconditionally discarding the final line of the file.
    rows = filter(isdata, view(lines, first_data:length(lines)))
    m = length(rows)

    X = zeros(num_variables, m)
    y = zeros(m)
    classes = Dict{String,Int}()

    for (col, row) in enumerate(rows)
        fields = split(row, ",")
        for j in 1:num_variables
            X[j, col] = parse(Float64, fields[j])
        end
        # Unseen labels get the next free id (1, 2, ...).
        y[col] = get!(classes, fields[end], length(classes) + 1)
    end

    return X, y, classes
end


In [None]:
#### read in the data, set up X, Y, and what the labels are
#### split into train and test (random 20% of the data)

# Load the data set (samples are the columns of X, y holds the class ids).
X, y, labels = read_data("./ITML Data/banana.dat")
num_classes = length(unique(y))
d, n = size(X)

# First 80% of a random permutation is the training split, the rest is test.
t = floor(Int, 0.8 * n)
p = randperm(n)
Xtrn, Xtst = X[:, p[1:t]], X[:, p[t+1:end]]
ytrn, ytst = y[p[1:t]], y[p[t+1:end]];

In [None]:
### Make similar and different pairs based on the class labels
### So, in fact, we *do* have ground truth here

"""
    make_SD(X, y)

Split all index pairs `(i, j)` with `i > j` into similar and dissimilar sets
according to the labels `y`.

# Arguments
- `X`: unused; kept so existing calls keep working.
- `y`: vector of labels, compared with `==`.

# Returns
`(S, D)`, two `Vector{Tuple{Int,Int}}`: `S` holds the pairs whose labels are
equal, `D` the pairs whose labels differ. Pairs are ordered by `i`, then `j`.
"""
function make_SD(X, y)
    n = length(y)
    # Concretely typed so the (potentially millions of) pairs are stored
    # inline rather than as boxed `Any` elements.
    S = Tuple{Int,Int}[]
    D = Tuple{Int,Int}[]

    for i in 1:n, j in 1:i-1
        push!(y[i] == y[j] ? S : D, (i, j))
    end

    return S, D
end

In [None]:
# Build the similar (S) and dissimilar (D) index pairs from the training split only.
S,D = make_SD(Xtrn,ytrn)

In [None]:
"""
    itml(X, S, D, u, l, A₀, γ; num_constraints, max_iters, eval_every, K)

Learn a Mahalanobis matrix by cyclic rank-one updates over a random subset of
the pairwise constraints in `S` and `D` (the update rules follow the ITML
scheme: `p`, `α`, `β`, slack `ξ`, dual `λ`).

# Arguments
- `X`: data matrix, one sample per column.
- `S`, `D`: collections of index pairs `(i, j)` into the columns of `X` that
  should be similar / dissimilar (as produced by `make_SD`).
- `u`, `l`: initial slack targets for pairs from `S` and from `D` respectively.
- `A₀`: starting matrix; it is copied (as `Float64`) and never mutated.
- `γ`: slack trade-off used in the `α` and `ξ` updates.

# Keywords
- `num_constraints`: number of distinct constraints to sample. Defaults to
  `20 * num_classes^2`, read from the global `num_classes` at call time. It is
  capped at `length(S) + length(D)`.
- `max_iters`: number of sweeps over the sampled constraints (default 100).
- `eval_every`: call `classify(A, K)` after every `eval_every`-th sweep
  (default 100); a value `<= 0` disables it.
- `K`: neighbour count handed to `classify` (default 5).

# Returns
The learned `Float64` matrix.

# Notes
Constraints are drawn uniformly from `S ∪ D`, and whether a constraint is
"similar" is decided by which collection it came from. Labels are therefore
always consistent with the columns of `X`, whatever global label vector exists.
"""
function itml(X, S, D, u, l, A₀, γ;
              num_constraints::Integer = 20 * num_classes^2,
              max_iters::Integer = 100,
              eval_every::Integer = 100,
              K::Integer = 5)
    # Work on a Float64 copy: the updates below are floating point anyway.
    A = Matrix{Float64}(A₀)
    n = size(X, 2)

    @show(n)
    flush(stdout)

    # Random access is needed for sampling; vectors are used as they are.
    Svec = S isa AbstractVector ? S : collect(S)
    Dvec = D isa AbstractVector ? D : collect(D)
    nS = length(Svec)
    total = nS + length(Dvec)

    # Never ask for more distinct constraints than exist, otherwise the
    # rejection loop below could not terminate.
    target = min(num_constraints, total)

    # Sample distinct positions in the concatenation [S; D].
    chosen = Set{Int}()
    while length(chosen) < target
        push!(chosen, rand(1:total))
    end
    picks = collect(chosen)

    # Per-constraint state: only sampled pairs are ever read or written, so
    # no n×n slack / dual matrices are required.
    Constraints = Tuple{Int,Int}[k <= nS ? Svec[k] : Dvec[k - nS] for k in picks]
    δs = Float64[k <= nS ? 1.0 : -1.0 for k in picks]   # +1 similar, -1 dissimilar
    ξ = Float64[k <= nS ? u : l for k in picks]          # slack targets
    λ = zeros(length(picks))                             # dual variables

    @show(length(Constraints))
    flush(stdout)

    e = 1e-30  # guards the divisions below against exact zeros
    for k in 1:max_iters
        for c in eachindex(Constraints)
            i, j = Constraints[c]
            δ = δs[c]

            v = X[:, i] - X[:, j]
            Av = A * v
            vA = v' * A
            p = dot(v, Av)  # current squared distance of the pair under A

            α = min(λ[c], δ / 2 * (1 / (p + e) - γ / (e + ξ[c])))
            β = δ * α / (1 - δ * α * p + e)

            ξ[c] = γ * ξ[c] / (γ + δ * α * ξ[c] + e)
            λ[c] -= α

            # A ← A + β A v vᵀ A, assembled from the two products computed
            # before A is modified.
            A .+= β .* (Av * vA)
        end

        if eval_every > 0 && k % eval_every == 0
            classify(A, K)
        end
    end

    return A
end

In [None]:

"""
    eig_proj(A)

Symmetrize `A` as `(A' + A) / 2`, zero out every eigenvalue that is not
strictly positive, and return the matrix rebuilt from the remaining spectrum.
"""
function eig_proj(A)
    vals, vecs = eigen((A' + A) / 2, permute=false)
    # Boolean mask: `true` keeps an eigenvalue, `false` multiplies it to zero.
    keep = vals .> 0
    clipped = Diagonal(vals .* keep)
    return vecs * clipped * vecs'
end

"""
    classify(A, K)

Evaluate a `K`-nearest-neighbour classifier under the metric `eig_proj(A)` and
return its accuracy on the test split.

Reads the globals `Xtrn`, `ytrn`, `Xtst`, `ytst` and `num_classes`. Each test
column is assigned the class with the most votes among its `K` closest
training columns (ties go to the lowest class id). If `K` exceeds the number
of training samples, all training samples vote.

The accuracy is also printed via `@show`.
"""
function classify(A, K)
    A = eig_proj(A)
    ntst = size(Xtst, 2)
    ntrn = size(Xtrn, 2)
    k = min(K, ntrn)  # avoid indexing past the end of the neighbour list

    ypred = zeros(ntst)
    dists = zeros(ntrn)
    votes = zeros(Int, num_classes)

    for i in 1:ntst
        x = Xtst[:, i]
        for j in 1:ntrn
            diff = Xtrn[:, j] - x
            q = diff' * A * diff
            # Rounding can leave the quadratic form a hair below zero even
            # after the projection; clamp so `sqrt` cannot throw.
            dists[j] = sqrt(max(q, zero(q)))
        end

        order = sortperm(dists)
        fill!(votes, 0)
        for label in view(ytrn, view(order, 1:k))
            votes[Int(label)] += 1
        end
        ypred[i] = argmax(votes)
    end

    # `@show` returns the value it prints.
    acc = @show(sum(ypred .== ytst)/length(ytst))
    flush(stdout)

    return acc
end

In [None]:
# Number of distinct class labels found in the full data set.
num_classes

In [None]:
# Total number of samples (train and test together).
length(y)

In [None]:
# Free system memory, converted from bytes to GiB.
Sys.free_memory()/2^(30)

In [None]:
# Same free-memory query as the previous cell, run again.
Sys.free_memory()/2^(30)

In [None]:
# Inputs for `itml`: identity starting matrix, the slack targets for
# similar (u) and all other (l) pairs, and the trade-off γ.
A₀ = Matrix(I, d, d)
u, l = 1, 10
γ = 1

In [None]:
# NOTE(review): `stochastic_itml` is not defined in the visible part of this
# file — confirm it is defined elsewhere before running this cell.
@time A₁ = stochastic_itml(Xtrn,S,D,u,l,Matrix(I,d,d),γ,1e8,5)

In [None]:
# Number of samples in the held-out test split.
length(ytst)

In [None]:
# One accuracy slot per repeated ITML run in the loop below.
avg_acc = zeros(100);

In [None]:
# Train from the identity matrix 100 times and store the 5-NN test accuracy
# of every run.
for run in 1:100
    A₂ = itml(Xtrn, S, D, u, l, Matrix(I, d, d), γ)
    avg_acc[run] = classify(A₂, 5)
end

@show mean(avg_acc)

In [None]:
# Mean accuracy over every recorded run. The divisor follows the length of
# `avg_acc` rather than a hard-coded 75, which did not match its 100 slots.
sum(avg_acc)/length(avg_acc)

In [None]:
# Scratch: sparse vector built from five random integers drawn from 1:5.
A = sparse(rand(1:5,5))

In [None]:
# Number of stored entries of the sparse container; the SparseArrays function
# for this is `nnz` (`nz` is not defined).
nnz(A)

In [None]:
# Number of training samples (columns of Xtrn).
size(Xtrn,2)

In [None]:
# Dimensions of the full data matrix: (variables, samples).
size(X)