### Libraries import and configurations setup

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import itertools
import numpy as np
import math
import pickle
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("Using CPU")

Using CPU


### Class definitions

In [3]:
class MatrixDataset(Dataset):
    def __init__(self, matrices, labels):
        self.matrices = matrices
        self.labels = labels

    def __len__(self):
        return len(self.matrices)

    def __getitem__(self, idx):
        matrix = self.matrices[idx]
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return matrix, label
    
class Prey_Net(nn.Module):
    def __init__(self):
        super(Prey_Net, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(9, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 5)

    def forward(self, x):
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
class Predator_Net(nn.Module):
    def __init__(self):
        super(Predator_Net, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(9, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 5)

    def forward(self, x):
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def find_closest_cell(t, point, n):
    """
    t: torch tensor filled with 1s and 2s
    point: tuple (x,y) of the currently observed point
    n: number to search for (1 or 2)
    """
    closest_dist = math.inf
    closest_cell = None
    
    cells = [(i, j) for i in range(t.size()[0]) for j in range(t.size()[1]) if t[i][j] == n]
    
    for cell_point in cells:
        dist = abs(point[0] - cell_point[0]) + abs(point[1] - cell_point[1]) 
        if dist < closest_dist:
            closest_dist = dist
            closest_cell = cell_point
                    
    if closest_cell is None:
        return math.inf
    else:
        return closest_dist


def find_best_move(t, agent):
# Moves are enumerated as # Stand, Top, Right, Bottom, Left
    n = 1 if agent == 2 else 2
    center = (1,1)

    if agent == 1:
        best_distance = -math.inf
    else:
        best_distance = math.inf

    best_moves = []
    for move in [ (center[0]-1, center[1], 1),  (center[0], center[1]+1, 2), (center[0]+1, center[1], 3), (center[0], center[1]-1, 4)]:
        if (agent == 1 and t[move[0]][move[1]] == 0) or (agent == 2 and t[move[0]][move[1]] != 2):
            new_distance = find_closest_cell(t, move[:2], n)
            
            if (agent == 1 and new_distance > best_distance) or (agent == 2 and new_distance < best_distance):
                best_distance = new_distance
                best_moves = [move[2]]
            elif new_distance == best_distance:
                best_moves.append(move[2])
      

    if len(best_moves) == 0:
        return [1, 0, 0, 0, 0]
      
    prob = 1/len(best_moves)
    prob_vec = [0]*5

    for i in best_moves:
        prob_vec[i] = prob

    return prob_vec


def infer(net, t):
    net.eval()

    with torch.no_grad():  # Disable gradient calculation to save memory and computation
        logits = net(t.unsqueeze(0).to(device))
        probabilities = torch.softmax(logits, dim=1)

        return [el.item() for el in probabilities[0]] 

### State and label generation

In [4]:
# Create a list of all possible combinations of 0, 1, and 2
vals = [0, 1, 2]
combinations = list(itertools.product(vals, repeat=9))

# Filter out the combinations where the middle element is not 0
filtered_combinations = list(filter(lambda x: x[4] == 1, combinations))
prey_tensors = list(set([torch.tensor(combination, dtype=torch.float32).reshape(3, 3) for combination in filtered_combinations]))
prey_labels = [find_best_move(t, 1) for t in prey_tensors]  

filtered_combinations = list(filter(lambda x: x[4] == 2, combinations))
predator_tensors = list(set([torch.tensor(combination, dtype=torch.float32).reshape(3, 3) for combination in filtered_combinations]))
predator_labels = [find_best_move(t, 2) for t in predator_tensors]  

### Prey training

In [5]:
batch_size = 32
dataset = MatrixDataset(prey_tensors, prey_labels)
trainloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

prey_net = Prey_Net().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(prey_net.parameters(), lr=0.001)

num_epochs = 500

for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = prey_net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch: {epoch + 1}, Loss: {running_loss / len(trainloader)}")

print("Finished training")

Epoch: 1, Loss: 0.802518988751526
Epoch: 2, Loss: 0.3143258895898699
Epoch: 3, Loss: 0.28577952721671573
Epoch: 4, Loss: 0.27549128889476937
Epoch: 5, Loss: 0.2714681120462788
Epoch: 6, Loss: 0.2684641587502748
Epoch: 7, Loss: 0.25676783167494877
Epoch: 8, Loss: 0.25112998674429265
Epoch: 9, Loss: 0.25143046574034156
Epoch: 10, Loss: 0.24396315882419461
Epoch: 11, Loss: 0.24137751197496665
Epoch: 12, Loss: 0.23738075022300867
Epoch: 13, Loss: 0.24061919897071365
Epoch: 14, Loss: 0.24166367658712332
Epoch: 15, Loss: 0.23325348536945084
Epoch: 16, Loss: 0.23176847169519338
Epoch: 17, Loss: 0.23379070521558373
Epoch: 18, Loss: 0.23423640994192327
Epoch: 19, Loss: 0.23059312520024267
Epoch: 20, Loss: 0.22746848027522504
Epoch: 21, Loss: 0.22488924702069976
Epoch: 22, Loss: 0.22835881867021032
Epoch: 23, Loss: 0.22398213184444932
Epoch: 24, Loss: 0.22560802742260175
Epoch: 25, Loss: 0.22571890052516483
Epoch: 26, Loss: 0.22952328913492484
Epoch: 27, Loss: 0.2246825432094005
Epoch: 28, Loss:

### Predator training

In [6]:
batch_size = 32
dataset = MatrixDataset(predator_tensors, predator_labels)
trainloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

predator_net = Predator_Net().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(predator_net.parameters(), lr=0.001)

num_epochs = 500

for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = predator_net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch: {epoch + 1}, Loss: {running_loss / len(trainloader)}")

print("Finished training")

Epoch: 1, Loss: 1.4006616166494426
Epoch: 2, Loss: 0.970431155226763
Epoch: 3, Loss: 0.5624161790875555
Epoch: 4, Loss: 0.4851560541654675
Epoch: 5, Loss: 0.4704221820693861
Epoch: 6, Loss: 0.47311316727145203
Epoch: 7, Loss: 0.46823258661818734
Epoch: 8, Loss: 0.46298563143224863
Epoch: 9, Loss: 0.46196499836260396
Epoch: 10, Loss: 0.4655613047116011
Epoch: 11, Loss: 0.4628882275767697
Epoch: 12, Loss: 0.45850987908389707
Epoch: 13, Loss: 0.45938827605574456
Epoch: 14, Loss: 0.4609707394271221
Epoch: 15, Loss: 0.46205239779162177
Epoch: 16, Loss: 0.46194052573257277
Epoch: 17, Loss: 0.45688112957224486
Epoch: 18, Loss: 0.45827145939602437
Epoch: 19, Loss: 0.4554665061290313
Epoch: 20, Loss: 0.4598648587187517
Epoch: 21, Loss: 0.45733796107258934
Epoch: 22, Loss: 0.4562915747316138
Epoch: 23, Loss: 0.4537683635660746
Epoch: 24, Loss: 0.4514201787736683
Epoch: 25, Loss: 0.4547569915218261
Epoch: 26, Loss: 0.4535793277245123
Epoch: 27, Loss: 0.4520368343128741
Epoch: 28, Loss: 0.44936142

### Getting the Q-tables and saving the data

In [7]:
def torch_tensor_to_str(arr):
    return '|'.join(''.join(str(int(i.item())) for i in row) for row in arr)

In [8]:
torch_tensor_to_str(predator_tensors[0]) 

'221|121|020'

In [9]:
prey_q_table = {torch_tensor_to_str(t) : infer(prey_net, t) for t in prey_tensors}
predator_q_table = {torch_tensor_to_str(t) : infer(predator_net, t) for t in predator_tensors}

In [10]:
with open('prey_q_table.pickle', 'wb') as f:
    pickle.dump(prey_q_table, f, protocol=pickle.HIGHEST_PROTOCOL)

with open('predator_q_table.pickle', 'wb') as f:
    pickle.dump(predator_q_table, f, protocol=pickle.HIGHEST_PROTOCOL)

In [14]:
with open("prey_q_table.pickle", "rb") as f:
    data = pickle.load(f)

In [12]:
data

{'221|121|020': [7.956546141940635e-07,
  7.01992852953151e-14,
  0.4974796175956726,
  6.019082530572123e-14,
  0.5025195479393005],
 '202|222|200': [1.025474731198974e-33,
  0.513136088848114,
  1.2476020572194102e-07,
  0.4865933656692505,
  0.00027034501545131207],
 '202|122|200': [2.4203714727683825e-38,
  1.4557047506968956e-05,
  1.4508736009485879e-12,
  9.605410923541058e-06,
  0.9999758005142212],
 '002|122|111': [1.2199057014681977e-17,
  6.517561246255354e-08,
  1.2749371870734993e-12,
  0.5033623576164246,
  0.49663764238357544],
 '202|022|200': [5.235951711949679e-41,
  0.3709825873374939,
  9.485867646930046e-10,
  0.3955477774143219,
  0.23346957564353943],
 '001|222|111': [1.0428168375143739e-17,
  3.125811417703517e-05,
  4.645442084933067e-14,
  0.9999687671661377,
  1.3026361322943103e-08],
 '011|021|011': [2.795443989340291e-15,
  0.3387818932533264,
  0.3277476727962494,
  0.33347025513648987,
  2.192073935702865e-07],
 '001|122|111': [1.0550236245318801e-17,
  1.

In [None]:
211
111
000

In [16]:
data['211|111|000']

[7.984173454560568e-21,
 3.657876952722905e-13,
 5.6942866422105e-07,
 0.9999992847442627,
 1.5603728797941585e-07]