In [None]:
import numpy as np
import torch
from torch import nn

In [None]:
# Size of the square seating and conflict matrices below (N_SEATS x N_SEATS);
# the number of people to seat equals the number of seats.
N_SEATS = 24
# Seats per row: seats are indexed row-major, so seat s and seat s + N_COLS are
# vertical neighbours (this is what the N_COLS-sized shifts in loss_fn rely on).
N_COLS = 6
# Number of distinct off-diagonal entries set to 1 in each generated conflict matrix.
N_CONFLICTS = 40

In [None]:
def generate_seats_matrix(seed=None, n_seats=None):
    """Return a random permutation matrix (one 1 in every row and every column).

    Args:
        seed: seed for a private ``np.random.RandomState``. A given seed yields
            exactly the matrix that ``np.random.seed(seed)`` followed by
            ``np.random.permutation`` would, but numpy's global RNG is left
            untouched. ``None`` draws fresh OS entropy.
        n_seats: size of the square matrix; defaults to ``N_SEATS``.

    Returns:
        int8 array of shape (n_seats, n_seats): the rows of the identity matrix
        in random order.
    """
    if n_seats is None:
        n_seats = N_SEATS
    # Private generator: avoids silently re-seeding the global RNG for other code.
    rng = np.random.RandomState(seed)
    return rng.permutation(np.eye(n_seats, dtype=np.int8))


seats_matrix = generate_seats_matrix(seed=0)
# Sanity check: a valid seating is a permutation matrix (exactly one 1 per row and per column).
assert (seats_matrix.sum(axis=0) == 1).all() and (seats_matrix.sum(axis=1) == 1).all()
print(seats_matrix)

[[0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0]
 [0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
 [0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0]
 [0 0 0 1 0 0 0 0 0 0 0 0 0 0 0

In [None]:
def generate_conflicts_matrix(seed=None, n_seats=None, n_conflicts=None):
    """Return a random 0/1 conflict matrix with ``n_conflicts`` off-diagonal ones.

    Entry ``[p, q] == 1`` marks a conflict of row ``p`` with row ``q``. The matrix
    is not symmetrised, so ``[q, p]`` may still be 0.

    Args:
        seed: seed for a private ``np.random.RandomState``. A given seed yields
            exactly the matrix the global-RNG version (``np.random.seed(seed)``)
            would, but numpy's global RNG is left untouched. ``None`` draws fresh
            OS entropy.
        n_seats: size of the square matrix; defaults to ``N_SEATS``.
        n_conflicts: number of distinct off-diagonal cells set to 1; defaults
            to ``N_CONFLICTS``.

    Returns:
        int8 array of shape (n_seats, n_seats) with a zero diagonal and a sum of
        exactly ``n_conflicts``.

    Raises:
        ValueError: (from ``choice``) if ``n_conflicts`` exceeds the
            ``n_seats * (n_seats - 1)`` available off-diagonal cells.
    """
    if n_seats is None:
        n_seats = N_SEATS
    if n_conflicts is None:
        n_conflicts = N_CONFLICTS
    rng = np.random.RandomState(seed)

    conflicts_matrix = np.zeros((n_seats, n_seats), dtype=np.int8)
    # Every (row, col) pair with row != col, in row-major order.
    possible_conflicts = np.argwhere(~np.eye(n_seats, dtype=bool))
    # Sampling without replacement guarantees n_conflicts *distinct* cells.
    chosen = rng.choice(len(possible_conflicts), size=n_conflicts, replace=False)
    rows, cols = possible_conflicts[chosen].T
    conflicts_matrix[rows, cols] = 1

    return conflicts_matrix


conflicts_matrix = generate_conflicts_matrix(seed=0)
# Each sampled pair marks a distinct off-diagonal cell, so the total must match exactly.
assert N_CONFLICTS == conflicts_matrix.sum()

print(conflicts_matrix)

[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0]
 [0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0]
 [0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0]
 [1 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
 [0 0 0 1 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0 0 0]
 [0 1 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0]
 [0 1 0 0 0 0 0 0 1 0 0 0 0 0 0

In [None]:
class Network(nn.Module):
    """Two-layer MLP mapping a conflict matrix to per-row seat distributions.

    Input: a batch whose non-batch dims flatten to ``N_SEATS * N_SEATS``
    values, e.g. shape (batch, N_SEATS, N_SEATS).
    Output: shape (batch, N_SEATS, N_SEATS); every ``[b, p, :]`` slice is a
    softmax distribution (non-negative, sums to 1) over seats.
    """

    def __init__(self, hidden_size=512):
        super().__init__()
        n_cells = N_SEATS * N_SEATS
        # fc1 is created before fc2 so parameter initialisation draws from the RNG in that order.
        self.fc1 = nn.Linear(n_cells, hidden_size)
        self.fc2 = nn.Linear(hidden_size, n_cells)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x.flatten(start_dim=1)))
        logits = self.fc2(hidden).reshape(-1, N_SEATS, N_SEATS)
        # Normalise along the last axis so each row is a probability distribution.
        return torch.softmax(logits, dim=2)


def loss_fn(seats_matrix, conflicts_matrix, n_cols=None):
    """Differentiable loss for a batch of soft seating assignments.

    Seats are indexed row-major in rows of ``n_cols``. Two seats are adjacent
    when they are side by side *in the same row*, or directly above/below.

    Args:
        seats_matrix: tensor of shape (batch, persons, seats); ``[b, p, :]`` is
            person ``p``'s distribution over seats (e.g. the output of ``Network``).
        conflicts_matrix: tensor broadcastable to (batch, persons, persons);
            non-zero ``[b, p, q]`` marks a conflict between persons ``p`` and ``q``.
        n_cols: seats per row; defaults to the module-level ``N_COLS``.

    Returns:
        ``(loss, n_conflicts, sittings)``:
        - ``loss``: scalar, ``n_conflicts`` plus the seat-occupancy penalty.
        - ``n_conflicts``: scalar, the summed expected adjacency of conflicting
          pairs over the whole batch.
        - ``sittings``: tensor (batch, seats) of expected occupants per seat.
    """
    if n_cols is None:
        n_cols = N_COLS

    n_seats = seats_matrix.shape[2]
    seat_idx = torch.arange(n_seats, device=seats_matrix.device)
    col = seat_idx % n_cols

    def _as_mask(keep):
        # bool -> float mask broadcast over the seat axis.
        return keep.to(seats_matrix.dtype)

    # Each shifted tensor holds, at [b, p, s], the probability that person p sits in
    # one particular neighbour of seat s. Masks zero seats lacking that neighbour.
    # This includes row boundaries: the last seat of a row is NOT next to the first
    # seat of the following row.
    next_in_row = torch.roll(seats_matrix, shifts=-1, dims=2) * _as_mask(
        (col != n_cols - 1) & (seat_idx < n_seats - 1)
    )
    prev_in_row = torch.roll(seats_matrix, shifts=1, dims=2) * _as_mask(col != 0)
    next_row = torch.roll(seats_matrix, shifts=-n_cols, dims=2) * _as_mask(seat_idx < n_seats - n_cols)
    prev_row = torch.roll(seats_matrix, shifts=n_cols, dims=2) * _as_mask(seat_idx >= n_cols)

    # proximity[b, p, q]: expected adjacency of persons p and q, treating each
    # person's seat distribution independently.
    neighbours = next_in_row + prev_in_row + next_row + prev_row
    proximity_matrix = neighbours @ seats_matrix.transpose(1, 2)
    n_conflicts = (proximity_matrix * conflicts_matrix).sum()
    loss_conflicts = n_conflicts

    # Occupancy penalty: each seat should hold exactly one person. Subtracting the
    # per-seat variance across persons rewards confident (peaked) assignments.
    sittings = seats_matrix.sum(dim=1)
    sittings_var = seats_matrix.var(dim=1)
    loss_sittings = ((torch.ones_like(sittings) - sittings) ** 2).sum() - sittings_var.sum()

    loss = loss_conflicts + loss_sittings
    return loss, n_conflicts, sittings


def decode_seats(predicted_seats_matrix, force_fill=False):
    """Convert per-person seat scores into one seat index per person.

    Args:
        predicted_seats_matrix: array of shape (batch, persons, seats); higher
            values mean a stronger preference for that seat.
        force_fill: if False, every person independently takes their
            highest-scoring seat, so several people may share a seat. If True,
            persons are assigned greedily in index order, each taking their
            highest-scoring seat that is still free, so seats are distinct.

    Returns:
        Integer array of shape (batch, persons) with the chosen seat indices.

    Raises:
        ValueError: if ``force_fill`` is True and there are more persons than seats.
    """
    if not force_fill:
        return predicted_seats_matrix.argmax(axis=2)

    # Rank every person's seats from best to worst once, for the whole batch.
    rankings = np.argsort(-np.asarray(predicted_seats_matrix), axis=2)
    batch_size, n_persons, n_seats = rankings.shape
    if n_persons > n_seats:
        raise ValueError(f"cannot give {n_persons} persons distinct seats among {n_seats}")

    all_seats = np.empty((batch_size, n_persons), dtype=rankings.dtype)
    for b in range(batch_size):
        taken = set()
        for p in range(n_persons):
            # Exactly one seat per person: the first one in their ranking that is still free.
            seat = next(s for s in rankings[b, p] if s not in taken)
            taken.add(seat)
            all_seats[b, p] = seat

    return all_seats


def check_all_issues(predicted_seats_matrix, conflicts_matrix, n_cols=None):
    """Count hard seating violations in argmax-decoded predictions.

    For every batch item two kinds of issue are counted:
    - sitting issues: each extra person placed on an already-occupied seat;
    - conflict issues: each non-zero conflict ``[p, q]`` whose two persons sit
      side by side in the same row, or directly above/below each other.

    Args:
        predicted_seats_matrix: numpy array of shape (batch, persons, seats).
        conflicts_matrix: numpy array of shape (batch, persons, persons).
        n_cols: seats per row (row-major seat indexing); defaults to ``N_COLS``.

    Returns:
        Total number of issues across the batch, as an ``int``.
    """
    if n_cols is None:
        n_cols = N_COLS

    decoded_seats = decode_seats(predicted_seats_matrix)
    n_issues = 0

    for pred_seats, conflicts in zip(decoded_seats, conflicts_matrix):
        # sitting issues
        n_issues += len(pred_seats) - len(np.unique(pred_seats))

        # conflict issues: look up the seats of both persons of every conflicting pair
        persons, others = np.nonzero(conflicts)
        seats_p = pred_seats[persons]
        seats_q = pred_seats[others]
        diffs = np.abs(seats_p - seats_q)
        # A difference of 1 only means "neighbours" when both seats are in the same row.
        same_row = (seats_p // n_cols) == (seats_q // n_cols)
        n_issues += int(np.count_nonzero((diffs == 1) & same_row))
        n_issues += int(np.count_nonzero(diffs == n_cols))

    return n_issues



network = Network()

# Smoke test of the untrained network on a single (batch-of-one) conflict matrix.
conflicts_matrix = generate_conflicts_matrix(seed=0)[np.newaxis, ...]
torch_conflicts_matrix = torch.from_numpy(conflicts_matrix).float()

# Pure evaluation: no gradients are needed, so don't build an autograd graph.
with torch.no_grad():
    predicted_seats_matrix = network(torch_conflicts_matrix)
    untrained_loss = loss_fn(predicted_seats_matrix, torch_conflicts_matrix)[0]

print('loss:', untrained_loss.item())
print('issues:', check_all_issues(predicted_seats_matrix.numpy(), conflicts_matrix))

loss: 5.702538967132568
issues: 20


In [None]:
n_epochs = 200
batch_size = 32  # number of conflict matrices trained on together

# Seed torch so the initial weights (and therefore this training log) are reproducible.
torch.manual_seed(0)
network = Network()

optimizer = torch.optim.Adam(network.parameters(), lr=0.001)

# Fixed, seeded training set of `batch_size` conflict matrices.
conflicts_matrix = np.array([generate_conflicts_matrix(seed=i) for i in range(batch_size)])
# The inputs never change between epochs, so convert them to a tensor once.
torch_conflicts_matrix = torch.from_numpy(conflicts_matrix).float()

for epoch in range(n_epochs):
    predicted_seats_matrix = network(torch_conflicts_matrix)
    loss, n_conflicts, sittings = loss_fn(predicted_seats_matrix, torch_conflicts_matrix)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Hard (argmax-decoded) violations of this epoch's forward-pass predictions.
    final_issues = check_all_issues(predicted_seats_matrix.detach().numpy(), conflicts_matrix)

    # NOTE: sittings.sum() is always batch_size * N_SEATS, since every row of the
    # network's softmax output sums to 1 -- a sanity check, not a progress metric.
    print(f"epoch={epoch:04d}, loss={loss.item():.10f}, conflicts={n_conflicts.item():.10f}, sittings={sittings.sum().item():.10f}, issues={final_issues}")

epoch=0000, loss=182.6718750000, conflicts=182.5465240479, sittings=768.0000000000, issues=462
epoch=0001, loss=181.0302429199, conflicts=180.0335845947, sittings=767.9999389648, issues=535
epoch=0002, loss=181.0392608643, conflicts=178.2886199951, sittings=768.0000000000, issues=555
epoch=0003, loss=180.8373565674, conflicts=177.9104614258, sittings=768.0000000000, issues=546
epoch=0004, loss=180.5297546387, conflicts=178.2104797363, sittings=768.0000000000, issues=561
epoch=0005, loss=180.3681640625, conflicts=178.6425018311, sittings=768.0000000000, issues=556
epoch=0006, loss=180.2866363525, conflicts=178.9318237305, sittings=768.0000000000, issues=534
epoch=0007, loss=180.1852569580, conflicts=179.0102844238, sittings=768.0000000000, issues=493
epoch=0008, loss=180.0308990479, conflicts=178.8942108154, sittings=768.0000610352, issues=467
epoch=0009, loss=179.8346557617, conflicts=178.6159057617, sittings=768.0000610352, issues=444
epoch=0010, loss=179.6218872070, conflicts=178.217