In [1]:
import numpy as np
import copy
from scipy.optimize import linprog
from collections import Counter

In [2]:
def NW_corner_method(a: np.ndarray, b: np.ndarray):
    """Build an initial transportation plan with the north-west corner rule.

    Starting in the top-left cell, each step ships as much as possible through
    the current cell, then moves right (column demand exhausted) or down (row
    supply exhausted) until the walk leaves the grid.

    Parameters
    ----------
    a : supplies, one entry per row.
    b : demands, one entry per column.

    Returns
    -------
    X : float array of shape (len(a), len(b)) holding the shipped amounts.
    B : list of (row, column) cells visited by the walk, in visiting order.
        These are the basic cells, including zero-valued (degenerate) ones.

    Raises
    ------
    AssertionError
        If the walk does not visit exactly len(a) + len(b) - 1 cells, which
        can happen when total supply and total demand differ.
    """
    # Work on private copies: the remaining supply/demand is consumed in place.
    a = copy.deepcopy(a)
    b = copy.deepcopy(b)

    m, n = len(a), len(b)
    X = np.zeros((m, n))
    B = []
    i, j = 0, 0

    while i < m and j < n:
        B.append((i, j))

        if a[i] > 0 and b[j] > 0:
            if a[i] > b[j]:
                # Column j is satisfied first: ship its whole demand, go right.
                a[i] -= b[j]
                X[i][j] = b[j]
                b[j] = 0
                j += 1
            else:
                # Row i is exhausted first (or both at once): ship the whole
                # remaining supply and go down. The demand that shrinks is the
                # one of the *current column* j.
                b[j] -= a[i]
                X[i][j] = a[i]
                a[i] = 0
                i += 1
        elif a[i] == 0 and i != m - 1:
            # Nothing left in this row; the cell stays in B as a degenerate
            # basic cell and the walk continues downwards.
            i += 1
        else:
            # Column demand is zero (or we are pinned to the last row).
            j += 1

    assert len(B) == m + n - 1

    return X, B
    

In [3]:
def solve(a: np.ndarray, b: np.ndarray, C: np.ndarray):
    """Solve a transportation problem with the method of potentials.

    Parameters
    ----------
    a : non-negative supplies, one per row of ``C``.
    b : non-negative demands, one per column of ``C``.
    C : non-negative unit-cost matrix, indexed ``C[i][j]``.

    Returns
    -------
    X : float array with the shipment plan. When total supply and total demand
        differ, the problem is balanced with a zero-cost dummy consumer (extra
        last column) or dummy supplier (extra last row), and the returned plan
        includes that extra column/row.

    Raises
    ------
    AssertionError
        If ``a``, ``b`` or ``C`` contains a negative entry.
    RuntimeError
        If no closed cycle can be built for the entering cell (the basis is
        not a valid one).
    """
    a, b, C = np.asarray(a), np.asarray(b), np.asarray(C)
    assert (a >= 0).all()
    assert (b >= 0).all()
    assert (C >= 0).all()

    # Balance the problem with a zero-cost dummy consumer / supplier.
    diff = np.sum(a) - np.sum(b)
    if diff > 0:
        b = np.append(b, diff)
        C = np.column_stack([C, np.zeros(len(a))])
    elif diff < 0:
        a = np.append(a, -diff)
        # vstack adds a row and keeps C two-dimensional; np.append without an
        # axis would flatten the cost matrix.
        C = np.vstack([C, np.zeros(len(b))])

    m, n = len(a), len(b)
    # Slack for the optimality test: the potentials come out of a floating
    # point linear solve, so exact comparison could pivot on rounding noise.
    tol = 1e-9

    X, B = NW_corner_method(a, b)

    while True:
        # 1, 2: one equation u[i] + v[j] = C[i][j] per basic cell, plus the
        # normalisation u[0] = 0 in the last row (its right-hand side stays 0).
        A = np.zeros((m + n, m + n))
        rhs = np.zeros(m + n)
        for line, (i, j) in enumerate(B):
            A[line][i] = 1
            A[line][m + j] = 1
            rhs[line] = C[i][j]
        A[-1][0] = 1

        # 3: potentials
        u_v = np.linalg.solve(A, rhs)
        u, v = u_v[:m], u_v[m:]

        # 4, 5: first non-basic cell (row-major order) violating u + v <= C.
        basic = set(B)
        entering = next(
            (
                (i, j)
                for i in range(m)
                for j in range(n)
                if (i, j) not in basic and u[i] + v[j] > C[i][j] + tol
            ),
            None,
        )

        if entering is None:  # 4: plan is optimal
            return X

        B.append(entering)  # 5

        # 6: repeatedly drop cells that are alone in their row or column;
        # what survives is the closed cycle through the entering cell.
        cycle = B.copy()
        while True:
            row_count = Counter(i for i, _ in cycle)
            col_count = Counter(j for _, j in cycle)
            kept = [
                (i, j) for (i, j) in cycle
                if row_count[i] > 1 and col_count[j] > 1
            ]
            if len(kept) == len(cycle):
                break
            cycle = kept

        if not cycle:
            raise RuntimeError("no cycle found for the entering cell")

        # 7: walk the cycle starting at the entering cell (last in the list),
        # alternately labelling cells "+" and "-". Each next cell shares a row
        # or a column with the previously labelled one.
        plus_pairs = [cycle.pop()]
        minus_pairs = []

        while cycle:
            if len(plus_pairs) > len(minus_pairs):
                last, target = plus_pairs[-1], minus_pairs
            else:
                last, target = minus_pairs[-1], plus_pairs

            for index, (i, j) in enumerate(cycle):
                if last[0] == i or last[1] == j:
                    target.append(cycle.pop(index))
                    break
            else:
                # Without this guard a broken cycle would loop forever.
                raise RuntimeError("cycle cells do not form a closed chain")

        theta = min(X[i][j] for (i, j) in minus_pairs)
        for (i, j) in plus_pairs:
            X[i][j] += theta
        for (i, j) in minus_pairs:
            X[i][j] -= theta

        # 8: exactly one emptied "-" cell leaves the basis, so the basis keeps
        # m + n - 1 cells even when several cells hit zero at once.
        for (i, j) in minus_pairs:
            if X[i][j] == 0:
                B.remove((i, j))
                break



In [4]:
# Transportation-problem fixtures: supplies 'a', demands 'b', unit costs 'C'.
tests = [
    # Balanced 3x3 instance (total supply == total demand == 700).
    dict(
        a=np.array([100, 300, 300]),
        b=np.array([300, 200, 200]),
        C=np.array([
            [8, 4, 1],
            [8, 4, 3],
            [9, 7, 5],
        ]),
    ),
    # Degenerate instance: nothing to ship and every cost is zero.
    dict(
        a=np.array([0, 0, 0]),
        b=np.array([0, 0, 0]),
        C=np.array([
            [0, 0, 0],
            [0, 0, 0],
            [0, 0, 0],
        ]),
    ),
]

In [5]:
# Run every fixture through solve(), show the plan, and cross-check its total
# cost against SciPy's LP solver.
for test in tests:
    a, b, C = test['a'], test['b'], test['C']
    m, n = C.shape

    result2 = solve(a, b, C)
    print(result2)

    # solve() pads unbalanced problems with a dummy consumer/supplier; only
    # the leading m x n part of its result refers to the original cells.
    plan = result2[:m, :n]

    # Reference LP over the plan flattened row-major: ship as much as possible
    # (the smaller of total supply and total demand) without exceeding any
    # single supply or demand, at minimum cost. linprog's default bounds
    # already enforce x >= 0.
    row_sums = np.kron(np.eye(m), np.ones((1, n)))
    col_sums = np.kron(np.ones((1, m)), np.eye(n))
    result1 = linprog(
        C.flatten(),
        A_ub=np.vstack([row_sums, col_sums]),
        b_ub=np.concatenate([a, b]),
        A_eq=np.ones((1, m * n)),
        b_eq=[min(a.sum(), b.sum())],
    )

    # Optimal plans need not be unique, so compare objective values rather
    # than the plans themselves.
    assert result1.success, result1.message
    assert np.isclose(result1.fun, (C * plan).sum())

[[  0.   0. 100.]
 [  0. 200. 100.]
 [300.   0.   0.]]
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]


In [6]:
# Scratch check: zip(*pairs) transposes a list of (i, j) pairs into a tuple of
# all first components and a tuple of all second components.
t = [
    (1, 1),
    (2, 2),
    (3, 3),
    (6, 7),
]
list(zip(*t))

[(1, 2, 3, 6), (1, 2, 3, 7)]