In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Sampler
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random


# from torchmetrics import Accuracy

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device", flush=True)

# Seed every RNG the notebook touches (Python, NumPy, torch) so runs are reproducible.
seed: int = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Training hyper-parameters.
epochs = 100          # number of passes over train_loader
learning_rate = 1e-3  # Adam step size
batch_size = 16       # samples per mini-batch

In [None]:
# load the data in the output folder

import os

# Directory holding the generated training data. Set MARL_OUTPUT_DIR to run the
# notebook elsewhere; the default keeps the original location.
output_dir = os.environ.get("MARL_OUTPUT_DIR", "/Users/ens/repos/marl/output")
data_file = "_trainingdata_groundmodel_exploit_True_numepi100_K10_L10_M2_N10_T10.npy"

# The file stores a Python object (a 0-d object array, unwrapped with .item()),
# so allow_pickle=True is required. Unpickling can execute arbitrary code:
# only load files you generated yourself.
data = np.load(os.path.join(output_dir, data_file), allow_pickle=True).item()

In [None]:
# Inspect which top-level entries the loaded training-data object holds.
data.keys()

In [None]:
# Number of entries in element 0 of data["states"].
# NOTE(review): presumably the sample count of one episode (or one agent) — confirm what index 0 selects.
len(data["states"][0])

In [None]:
# Training inputs: only element 0 of data["states"] is used below.
# NOTE(review): if data["states"] is indexed by episode, this trains on a single
# episode out of the whole file — confirm whether the others should be concatenated.
states = data["states"][0]

In [None]:
# Shape of one state sample; its first dimension becomes the network's input width.
states[0].shape

In [None]:
# Training targets: only element 0 of data["actions"], aligned with `states`.
# NOTE(review): same caveat as for states — confirm that index 0 is the intended subset.
actions = data["actions"][0]

In [None]:
# Shape of one action sample; its first dimension becomes the network's output width.
actions[0].shape

In [None]:
# Hidden width passed to Net; a free design choice, not derived from the data.
n_hidden = 256
# Input/output widths are read off the first sample so the network matches the data.
n_input, n_out = states[0].shape[0], actions[0].shape[0]

In [None]:
# Sanity-check the chosen layer widths (input, hidden, output).
n_input, n_hidden, n_out

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class Net(nn.Module):
    """Fully connected network: n_input -> n_hidden -> n_hidden -> n_out.

    The two hidden layers use ReLU. The last layer is left linear, so the
    network emits raw scores (logits), which suits nn.CrossEntropyLoss.

    Args:
        n_input: width of each input vector.
        n_hidden: width of both hidden layers.
        n_out: number of output scores.
    """

    def __init__(self, n_input, n_hidden, n_out):
        super().__init__()
        # Creation order is kept fixed so weight initialisation draws from the RNG
        # in the same sequence; attribute names double as state_dict keys.
        self.fc1 = nn.Linear(n_input, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        self.fc3 = nn.Linear(n_hidden, n_out)

    def forward(self, x):
        """Map a batch of inputs to unnormalised output scores."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return self.fc3(hidden)

# Instantiate the network and move its parameters onto the selected device;
# Module.to() returns the module itself, so the call can be chained.
net = Net(n_input, n_hidden, n_out).to(device)

# Loss (default 'mean' reduction over the batch) and optimizer over all parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), lr=learning_rate)

In [None]:
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    """Map-style dataset pairing ``states[i]`` with ``actions[i]``.

    Args:
        states: indexable sequence of model inputs.
        actions: indexable sequence of targets, aligned index-for-index with ``states``.

    Raises:
        ValueError: if ``states`` and ``actions`` have different lengths, which
            would otherwise silently drop trailing targets or fail with an
            IndexError part-way through an epoch.
    """

    def __init__(self, states, actions):
        if len(states) != len(actions):
            raise ValueError(
                "states and actions must have the same length, "
                f"got {len(states)} and {len(actions)}"
            )
        self.states = states
        self.actions = actions

    def __len__(self):
        return len(self.states)

    def __getitem__(self, idx):
        # Items are returned unchanged; DataLoader's default collate converts
        # array-like items (e.g. numpy arrays) into batched tensors.
        return self.states[idx], self.actions[idx]
    
# Pair each state with its action and serve shuffled mini-batches. With no explicit
# generator, the shuffle order is drawn from torch's default (seeded) RNG.
dataset = CustomDataset(states=states, actions=actions)
train_loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
import os

# Directory for the saved weights. Set MARL_OUTPUT_DIR to run the notebook
# elsewhere; the default keeps the original location.
model_dir = os.environ.get("MARL_OUTPUT_DIR", "/Users/ens/repos/marl/output")

for epoch in range(epochs):
    running_loss = 0.0  # sum of per-batch mean losses for this epoch
    # Unpack batches into dedicated names rather than `data`, which would
    # overwrite the dict loaded from disk and break re-runs of earlier cells.
    for batch_states, batch_actions in train_loader:
        # The network's parameters use the default float32 dtype. Cast inputs to
        # match, because arrays from np.load are often float64, which nn.Linear rejects.
        inputs = batch_states.float().to(device)
        labels = batch_actions.float().to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        # Assuming each action is a length-n_out vector, CrossEntropyLoss treats
        # these float targets as per-class probabilities.
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, loss: {running_loss}", flush=True)

torch.save(net.state_dict(), os.path.join(model_dir, "groundmodel.pt"))