In [1]:
using CSV
using DataFrames
using JLD2
using LinearAlgebra
using ProgressMeter
using Random
using SparseArrays
using StatsBase

In [2]:
# Seed the default RNG so that the `sample` calls in the search loop below, which do not
# pass an explicit RNG, draw the same sequence on every fresh run.
Random.seed!(1234)

TaskLocalRNG()

In [3]:
# Load the neuron table. Later cells read its "cell id", "column id" and "cell type"
# columns; the trailing `;` suppresses the notebook display of the DataFrame.
df_neurons = CSV.read("./data/ol_columns.csv", DataFrame; header=true);

In [4]:
# N: number of neurons (one table row per neuron).
N = nrow(df_neurons)
# K: number of distinct columns. The "not assigned" placeholder is excluded by value
# instead of subtracting 1, so K stays correct even when no neuron is unassigned.
K = count(!isequal("not assigned"), unique(df_neurons[!, "column id"]))
# T: number of distinct cell types.
T = length(unique(df_neurons[!, "cell type"]))
N, K, T

(23452, 796, 31)

In [5]:
# Lookup from "cell id" to the neuron's row position (1..N) in df_neurons.
d_neuron = Dict(zip(df_neurons[!, "cell id"], 1:N))
# Lookup from cell-type name to its index (1..T), numbered in order of first appearance.
d_type = Dict(zip(unique(df_neurons[!, "cell type"]), 1:T));

In [6]:
@time begin
    # X[i, k] == 1 when neuron i is assigned to column k; an all-zero row means the
    # neuron's "column id" is "not assigned".
    X = zeros(Int8, N, K)
    # Y[i, t] == 1 when neuron i has cell type t (index taken from d_type).
    Y = zeros(Int8, N, T)

    for (idx, neuron) in enumerate(eachrow(df_neurons))
        Y[idx, d_type[neuron["cell type"]]] = 1
        # Unassigned neurons keep their all-zero row in X.
        neuron["column id"] == "not assigned" && continue
        # The parsed column id is used directly as the column index of X.
        X[idx, parse(Int, neuron["column id"])] = 1
    end
end

  0.126217 seconds (436.58 k allocations: 32.506 MiB, 82.39% compilation time)


In [7]:
# Load the connection table. The next cell reads its "from cell id", "to cell id" and
# "synapses" columns.
df_conn = CSV.read("./data/ol_connections.csv", DataFrame);

In [8]:
@time begin
    # Dense N x N weight matrix: W[i, j] holds the "synapses" value of the connection
    # from neuron i to neuron j (row positions taken from d_neuron).
    W = zeros(Int16, N, N)

    # NOTE(review): a later row with the same (from, to) pair overwrites the earlier
    # value instead of adding to it -- confirm the table has one row per pair.
    for conn in eachrow(df_conn)
        W[d_neuron[conn["from cell id"]], d_neuron[conn["to cell id"]]] = conn["synapses"]
    end
end

  0.288091 seconds (3.80 M allocations: 1.106 GiB, 4.34% gc time, 5.13% compilation time)


In [9]:
function f(X, W)
    # Objective: sum of W[i, j] weighted by the dot product of rows i and j of X.
    # With one-hot rows in X this is the total weight between neurons that share a
    # column (self-connections included).
    membership = sparse(X)
    # Entry (i, j) is the dot product of the membership rows of neurons i and j.
    same_column = membership * membership'
    return sum(sparse(W) .* same_column)
end

f (generic function with 1 method)

In [10]:
# Objective of the initial assignment. curr_f starts at the same value and is updated
# incrementally by the search loop each time a swap is accepted.
curr_f = orig_f = f(X, W)

1381915

In [11]:
# Cell types whose neurons are candidates for swapping, and their indices in d_type
# (which are also the matching column indices of Y).
types_of_interest = ["R7", "R8", "T2a", "T3", "Tm3"]
type_list = map(cell_type -> d_type[cell_type], types_of_interest)

5-element Vector{Int64}:
 29
 31
 14
 13
 12

In [12]:
function find_column(X, idx)
    # Return the index of the single nonzero entry in row `idx` of X, i.e. the column
    # the neuron is assigned to. An all-zero row (unassigned neuron) yields NaN, which
    # find_colmates recognises via isnan.
    hits = findall(!iszero, view(X, idx, :))
    isempty(hits) && return NaN
    length(hits) == 1 && return hits[1]
    # More than one nonzero entry: the row is not a valid one-hot assignment.
    error("Something wrong with column assignment")
end

function find_colmates(X, k)
    # Find all the neurons in a given column k ("column mates"), i.e. the row indices
    # whose entry in column k of X equals 1.
    # k may be NaN (the "unassigned" marker returned by find_column); in that case the
    # result is an empty vector with the same element type as the regular result, so
    # callers always get a Vector{Int}.
    isnan(k) && return Int[]
    # Scan a view of the column instead of copying it and materialising a mask.
    return findall(==(1), view(X, :, k))
end

function update_colmates(colmates, idx_remove, idx_add)
    # Update the colmates after swap: drop `idx_remove`, make sure `idx_add` is present.
    # An empty input stays empty (delta_f passes an empty list for an unassigned
    # neuron); the empty result keeps the element type of the input.
    isempty(colmates) && return similar(colmates, 0)
    # Keep the original order and drop duplicates, so the result is deterministic while
    # containing exactly the elements of (colmates \ {idx_remove}) ∪ {idx_add}.
    updated = unique!(filter(!=(idx_remove), colmates))
    idx_add in updated || push!(updated, idx_add)
    return updated
end

function _column_contribution(W, idx, members)
    # Total weight between neuron `idx` and the neurons in `members`, in both
    # directions. An empty `members` means the neuron is in no column and contributes
    # nothing, so the self-connection correction must not be applied.
    total = sum(view(W, idx, members)) + sum(view(W, members, idx))
    # When idx is in `members`, W[idx, idx] was counted by both sums; remove one copy.
    return isempty(members) ? total : total - W[idx, idx]
end

function delta_f(X, W, idx_1, idx_2)
    # Calculate the difference in objective function after the swap, i.e. the value of
    # f(X, W) after exchanging rows idx_1 and idx_2 of X minus its value before.
    # X is not modified.
    col_1 = find_column(X, idx_1)
    col_2 = find_column(X, idx_2)

    # Pre swap: what each neuron contributes within its current column.
    colmates_1_pre = find_colmates(X, col_1)
    colmates_2_pre = find_colmates(X, col_2)
    loss_1 = _column_contribution(W, idx_1, colmates_1_pre)
    loss_2 = _column_contribution(W, idx_2, colmates_2_pre)

    # Swapping two neurons of the same column leaves the assignment unchanged.
    # (Two unassigned neurons have NaN columns, which never compare equal; they fall
    # through and get a zero difference from the general formula.)
    col_1 == col_2 && return zero(loss_1 + loss_2)

    # Post swap: idx_1 joins the former column of idx_2 and vice versa.
    colmates_1_post = update_colmates(colmates_2_pre, idx_2, idx_1)
    colmates_2_post = update_colmates(colmates_1_pre, idx_1, idx_2)
    gain_1 = _column_contribution(W, idx_2, colmates_2_post)
    gain_2 = _column_contribution(W, idx_1, colmates_1_post)

    return (gain_1 + gain_2) - (loss_1 + loss_2)
end

delta_f (generic function with 1 method)

In [13]:
function swap_rows!(X, idx_1, idx_2)
    # Exchange rows idx_1 and idx_2 of X in place. Both rows are copied before either
    # is overwritten. The return value is the pair of copied rows
    # (old row idx_2, old row idx_1), matching the value of the tuple assignment.
    row_1 = X[idx_1, :]
    row_2 = X[idx_2, :]
    X[idx_1, :] = row_2
    X[idx_2, :] = row_1
    return (row_2, row_1)
end

swap_rows! (generic function with 1 method)

In [14]:
@time begin
    # Accepted swaps in the order they were applied, as (idx_1, idx_2) row-index pairs.
    swap_list = Tuple{Int,Int}[]

    # Y never changes during the search, so the neurons of each type of interest are
    # looked up once here instead of rescanning a column of Y on every trial.
    candidates_by_type = Dict(t => findall(==(1), view(Y, :, t)) for t in type_list)

    # Make sure the checkpoint directory exists before the first save.
    results_dir = "../results"
    mkpath(results_dir)

    @showprogress for trial in 1:100000000
        # curr_f is a global running total; declare it so the update below also works
        # outside interactive (soft-scope) sessions.
        global curr_f

        # Same sampling calls as before, so a seeded run draws the same trial sequence:
        # pick one type of interest, then two distinct neurons of that type.
        t = sample(type_list, 1)
        candidate_neurons = candidates_by_type[only(t)]
        idx_1, idx_2 = sample(candidate_neurons, 2, replace=false)

        # Greedy rule: apply the swap only if it strictly increases the objective.
        d_f = delta_f(X, W, idx_1, idx_2)
        if d_f > 0
            swap_rows!(X, idx_1, idx_2)
            curr_f += d_f
            # Cumulative improvement over the initial objective (printed as "d_f").
            total_gain = curr_f - orig_f
            println("Trial: $trial, Curr_f: $curr_f, d_f: $total_gain")
            push!(swap_list, (idx_1, idx_2))

            # Checkpoint the full swap history under the current objective value.
            file_path = joinpath(results_dir, "swap_list_$(Int(curr_f)).jld2")
            @save file_path swap_list
        end
    end
end


[32mProgress:   0%|                                         |  ETA: 2:51:53[39m[39m

Trial: 48347, Curr_f: 1381927, d_f: 12


[32mProgress:   0%|                                         |  ETA: 3:02:21[39m

Trial: 68378, Curr_f: 1381943, d_f: 28


[32mProgress:   0%|                                         |  ETA: 2:32:57[39m

Trial: 146229, Curr_f: 1381947, d_f: 32


[32mProgress:   0%|                                         |  ETA: 2:31:48[39m

Trial: 152555, Curr_f: 1382017, d_f: 102


[32mProgress:   0%|▏                                        |  ETA: 2:31:14[39m

Trial: 156735, Curr_f: 1382033, d_f: 118


[32mProgress:   0%|▏                                        |  ETA: 2:26:13[39m

Trial: 200656, Curr_f: 1382052, d_f: 137


[32mProgress:   0%|▏                                        |  ETA: 2:20:36[39m

Trial: 273265, Curr_f: 1382063, d_f: 148


[32mProgress:   0%|▏                                        |  ETA: 2:20:23[39m

Trial: 281034, Curr_f: 1382074, d_f: 159


[32mProgress:   0%|▏                                        |  ETA: 2:19:06[39m

Trial: 312971, Curr_f: 1382083, d_f: 168


[32mProgress:   0%|▏                                        |  ETA: 2:16:22[39m

Trial: 388231, Curr_f: 1382110, d_f: 195


[32mProgress:   0%|▏                                        |  ETA: 2:15:00[39m

Trial: 453952, Curr_f: 1382131, d_f: 216


[32mProgress:   0%|▎                                        |  ETA: 2:14:25[39m

Trial: 480235, Curr_f: 1382135, d_f: 220


[32mProgress:   1%|▎                                        |  ETA: 2:13:04[39m

Trial: 562079, Curr_f: 1382141, d_f: 226


[32mProgress:   1%|▎                                        |  ETA: 2:12:33[39m

Trial: 632961, Curr_f: 1382145, d_f: 230


[32mProgress:   1%|▎                                        |  ETA: 2:12:01[39m

Trial: 669690, Curr_f: 1382147, d_f: 232


[32mProgress:   1%|▎                                        |  ETA: 2:11:50[39m

Trial: 679852, Curr_f: 1382155, d_f: 240


[32mProgress:   1%|▍                                        |  ETA: 2:10:57[39m

Trial: 772022, Curr_f: 1382163, d_f: 248


[32mProgress:   1%|▍                                        |  ETA: 2:10:48[39m

Trial: 817549, Curr_f: 1382167, d_f: 252


[32mProgress:   1%|▍                                        |  ETA: 2:10:43[39m

Trial: 879304, Curr_f: 1382171, d_f: 256


[32mProgress:   1%|▍                                        |  ETA: 2:10:43[39m

Trial: 896284, Curr_f: 1382179, d_f: 264


[32mProgress:   1%|▍                                        |  ETA: 2:10:38[39m

Trial: 909618, Curr_f: 1382182, d_f: 267


[32mProgress:   1%|▍                                        |  ETA: 2:10:26[39m

Trial: 958943, Curr_f: 1382185, d_f: 270


[32mProgress:   1%|▍                                        |  ETA: 2:10:12[39m

Trial: 1056953, Curr_f: 1382186, d_f: 271


[32mProgress:   1%|▌                                        |  ETA: 2:09:51[39m

Trial: 1146012, Curr_f: 1382207, d_f: 292


[32mProgress:   1%|▌                                        |  ETA: 2:09:42[39m

Trial: 1166073, Curr_f: 1382209, d_f: 294


[32mProgress:   1%|▌                                        |  ETA: 2:09:37[39m

Trial: 1180613, Curr_f: 1382222, d_f: 307


[32mProgress:   1%|▌                                        |  ETA: 2:09:22[39m

Trial: 1219579, Curr_f: 1382248, d_f: 333


[32mProgress:   1%|▌                                        |  ETA: 2:09:02[39m

Trial: 1268990, Curr_f: 1382279, d_f: 364


[32mProgress:   1%|▋                                        |  ETA: 2:08:35[39m

Trial: 1419200, Curr_f: 1382289, d_f: 374


[32mProgress:   2%|▋                                        |  ETA: 2:08:24[39m

Trial: 1527617, Curr_f: 1382297, d_f: 382


[32mProgress:   2%|▋                                        |  ETA: 2:08:24[39m

Trial: 1569133, Curr_f: 1382300, d_f: 385


[32mProgress:   2%|▋                                        |  ETA: 2:08:20[39m

Trial: 1636381, Curr_f: 1382307, d_f: 392


[32mProgress:   2%|▋                                        |  ETA: 2:08:22[39m

Trial: 1655062, Curr_f: 1382327, d_f: 412


[32mProgress:   2%|▊                                        |  ETA: 2:08:40[39m

Trial: 1744339, Curr_f: 1382333, d_f: 418


[32mProgress:   2%|▊                                        |  ETA: 2:08:51[39m

Trial: 1806609, Curr_f: 1382346, d_f: 431


[32mProgress:   2%|▉                                        |  ETA: 2:08:56[39m

Trial: 1989722, Curr_f: 1382355, d_f: 440


[32mProgress:   2%|▉                                        |  ETA: 2:08:52[39m

Trial: 2014117, Curr_f: 1382360, d_f: 445


[32mProgress:   2%|▉                                        |  ETA: 2:08:50[39m

Trial: 2067887, Curr_f: 1382367, d_f: 452


[32mProgress:   2%|▉                                        |  ETA: 2:09:01[39m

Trial: 2098950, Curr_f: 1382372, d_f: 457


[32mProgress:   2%|▉                                        |  ETA: 2:09:39[39m

Trial: 2142353, Curr_f: 1382378, d_f: 463


[32mProgress:   2%|█                                        |  ETA: 2:09:16[39m

Trial: 2308527, Curr_f: 1382387, d_f: 472


[32mProgress:   2%|█                                        |  ETA: 2:08:44[39m

Trial: 2455315, Curr_f: 1382399, d_f: 484


[32mProgress:   3%|█                                        |  ETA: 2:08:25[39m

Trial: 2547559, Curr_f: 1382407, d_f: 492


[32mProgress:   3%|█▏                                       |  ETA: 2:08:05[39m

Trial: 2659796, Curr_f: 1382411, d_f: 496


[32mProgress:   3%|█▏                                       |  ETA: 2:07:56[39m

Trial: 2714388, Curr_f: 1382415, d_f: 500


[32mProgress:   3%|█▏                                       |  ETA: 2:07:46[39m

Trial: 2799638, Curr_f: 1382430, d_f: 515


[32mProgress:   3%|█▎                                       |  ETA: 2:06:50[39m

Trial: 3087132, Curr_f: 1382441, d_f: 526


[32mProgress:   3%|█▍                                       |  ETA: 2:06:10[39m

Trial: 3204974, Curr_f: 1382444, d_f: 529


[32mProgress:   3%|█▍                                       |  ETA: 2:04:57[39m

Trial: 3417904, Curr_f: 1382445, d_f: 530


[32mProgress:   4%|█▌                                       |  ETA: 2:04:20[39m

Trial: 3546079, Curr_f: 1382455, d_f: 540


[32mProgress:   4%|█▌                                       |  ETA: 2:03:41[39m

Trial: 3685377, Curr_f: 1382465, d_f: 550


[32mProgress:   4%|█▋                                       |  ETA: 2:02:13[39m

Trial: 4013538, Curr_f: 1382466, d_f: 551


[32mProgress:   4%|█▊                                       |  ETA: 2:00:57[39m

Trial: 4313799, Curr_f: 1382477, d_f: 562


[32mProgress:   4%|█▊                                       |  ETA: 2:00:50[39m

Trial: 4351968, Curr_f: 1382484, d_f: 569


[32mProgress:   4%|█▊                                       |  ETA: 2:00:37[39m

Trial: 4420708, Curr_f: 1382495, d_f: 580


[32mProgress:   5%|█▉                                       |  ETA: 2:00:18[39m

Trial: 4501418, Curr_f: 1382503, d_f: 588


[32mProgress:   5%|█▉                                       |  ETA: 2:00:10[39m

Trial: 4538857, Curr_f: 1382508, d_f: 593


[32mProgress:   5%|█▉                                       |  ETA: 1:59:33[39m

Trial: 4721432, Curr_f: 1382514, d_f: 599


[32mProgress:   5%|██                                       |  ETA: 1:59:31[39m

Trial: 4732538, Curr_f: 1382550, d_f: 635


[32mProgress:   5%|██▏                                      |  ETA: 1:58:34[39m

Trial: 5043032, Curr_f: 1382552, d_f: 637


[32mProgress:   5%|██▏                                      |  ETA: 1:58:10[39m

Trial: 5121700, Curr_f: 1382584, d_f: 669


[32mProgress:   5%|██▏                                      |  ETA: 1:57:43[39m

Trial: 5218636, Curr_f: 1382598, d_f: 683


[32mProgress:   6%|██▍                                      |  ETA: 1:55:36[39m

Trial: 5775695, Curr_f: 1382643, d_f: 728


[32mProgress:   6%|██▌                                      |  ETA: 1:54:42[39m

Trial: 6062163, Curr_f: 1382650, d_f: 735


[32mProgress:   7%|██▋                                      |  ETA: 1:53:06[39m

Trial: 6551089, Curr_f: 1382652, d_f: 737


[32mProgress:   7%|██▉                                      |  ETA: 1:51:41[39m

Trial: 7017020, Curr_f: 1382688, d_f: 773


[32mProgress:   7%|███▏                                     |  ETA: 1:50:16[39m

Trial: 7497846, Curr_f: 1382690, d_f: 775


[32mProgress:   8%|███▏                                     |  ETA: 1:49:54[39m

Trial: 7631055, Curr_f: 1382702, d_f: 787


[32mProgress:   8%|███▍                                     |  ETA: 1:48:27[39m

Trial: 8300012, Curr_f: 1382721, d_f: 806


[32mProgress:   8%|███▍                                     |  ETA: 1:48:27[39m

Trial: 8300790, Curr_f: 1382725, d_f: 810


[32mProgress:   8%|███▌                                     |  ETA: 1:48:08[39m

Trial: 8443377, Curr_f: 1382733, d_f: 818


[32mProgress:   9%|███▊                                     |  ETA: 1:46:44[39m

Trial: 9107319, Curr_f: 1382779, d_f: 864


[32mProgress:  10%|███▉                                     |  ETA: 1:45:58[39m

Trial: 9539058, Curr_f: 1382796, d_f: 881


[32mProgress:  10%|████                                     |  ETA: 1:45:39[39m

Trial: 9730764, Curr_f: 1382797, d_f: 882


[32mProgress:  10%|████                                     |  ETA: 1:45:29[39m

Trial: 9827735, Curr_f: 1382810, d_f: 895


[32mProgress:  10%|████▎                                    |  ETA: 1:44:16[39m

Trial: 10413617, Curr_f: 1382827, d_f: 912


[32mProgress:  11%|████▎                                    |  ETA: 1:44:04[39m

Trial: 10507612, Curr_f: 1382833, d_f: 918


[32mProgress:  11%|████▍                                    |  ETA: 1:43:35[39m

Trial: 10738103, Curr_f: 1382889, d_f: 974


[32mProgress: 100%|█████████████████████████████████████████| Time: 3:40:25[39m


13225.725133 seconds (10.79 G allocations: 10.096 TiB, 5.49% gc time, 0.02% compilation time: 35% of which was recompilation)


In [15]:
# Recompute the objective from scratch on the final X. The displayed value can be
# compared with the incrementally tracked curr_f to check the delta_f bookkeeping.
f(X, W)

1382889