In [1]:
from tools import *

In [30]:
# Load the training dataset from disk.
# Bound to `train_data` (the name the model-setup and training cells read)
# instead of `data`: the training loop calls `data.DataLoader`, presumably the
# `torch.utils.data` module star-imported from `tools`, and rebinding `data`
# here would shadow it.
# NOTE(review): torch.load unpickles the file — only load trusted datasets.
train_data = torch.load("dataset/60/1.pt")

In [None]:
# Train on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [None]:
from math import ceil

# NOTE(review): CollectivMotoinModel is defined in a later cell; on a fresh
# kernel that cell has to run before this one.
# The network outputs a continuous (x, y) location, so train it with a
# regression loss (mean-squared error). CrossEntropyLoss is a classification
# loss and does not fit real-valued coordinate targets.
model = CollectivMotoinModel(optim.Adam, nn.MSELoss, lr=0.01).to(device)
epoches = 10  # number of passes over train_data
batch_size = 64
# Per-step training log: one pre-allocated row per optimisation step
# (epoches x batches-per-epoch, where the last batch may be partial).
run_data = pd.DataFrame(
    {"ephoc": 0.0, "batch": 0.0, "loss": 0.0},
    index=range(epoches * ceil(len(train_data) / batch_size)),
)

In [None]:
# Train `model` for `epoches` passes over `train_data`, logging every step into
# `run_data` and printing progress every 100 batches.

# Use the fully-qualified DataLoader rather than the bare name `data`, which is
# easily shadowed by a variable of the same name. The loader is built once:
# with shuffle=True it reshuffles automatically at the start of each epoch.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
# The model instance owns its loss and optimizer (created in its __init__).
criterion = model.loss_func
optimizer = model.optimizer
# Equals ceil(len(train_data) / batch_size), the per-epoch row count run_data was sized with.
steps_per_epoch = len(train_loader)

model.train()
for epoch in range(epoches):
    for i, (x, y) in enumerate(train_loader):
        x = x.to(device)
        y = y.to(device)
        y_pred = model(x)
        loss = criterion(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()  # single host sync per step
        run_data.loc[epoch * steps_per_epoch + i, ["ephoc", "batch", "loss"]] = (epoch, i, loss_value)
        if i % 100 == 0:
            print(f"epoch {epoch}, step {i}, loss {loss_value:.4f}")

In [None]:
class CollectivMotoinModel(nn.Module):
    """Feed-forward network predicting an agent's next (x, y) location.

    Each input row carries 4 features for the focal agent followed by 4
    features for each of its ``num_neighbors`` neighbors, so the network maps
    ``4 * (num_neighbors + 1)`` inputs to 2 outputs through two ReLU hidden
    layers of widths ``8 * (num_neighbors + 1)`` and ``2 * (num_neighbors + 1)``.

    The optimizer and loss are instantiated in the constructor and exposed as
    ``self.optimizer`` and ``self.loss_func``.
    """

    def __init__(
        self,
        optimizer: OptimazerType,
        loss_func: LossType,
        lr: float = 0.001,
        num_neighbors: int = 7,
    ):
        """Build the layers, then the optimizer over them and the loss.

        Args:
            optimizer (OptimazerType): optimizer factory, called as
                ``optimizer(self.parameters(), lr=lr)``.
            loss_func (LossType): loss factory, called with no arguments.
            lr (float): learning rate handed to the optimizer.
            num_neighbors (int): number of neighbors described in each input row.
        """
        super().__init__()
        self.num_neighbors = num_neighbors
        agents = num_neighbors + 1  # focal agent plus its neighbors
        self.model = nn.Sequential(
            nn.Linear(agents * 4, agents * 8),
            nn.ReLU(),
            nn.Linear(agents * 8, agents * 2),
            nn.ReLU(),
            nn.Linear(agents * 2, 2),
        )
        # Created after the layers so the optimizer sees every parameter.
        self.optimizer = optimizer(self.parameters(), lr=lr)
        self.loss_func = loss_func()

    def forward(self, x: TensorType) -> TensorType:
        """Predict the next location of the agents.

        Args:
            x (TensorType): tensor of shape (N, 4 + 4*m), m = num_neighbors,
                holding the agent's current location, speed and velocity
                followed by the neighbors' relative location, speed and velocity.

        Returns:
            TensorType: predicted new location of each agent as (x, y), shape (N, 2).
        """
        prediction = self.model(x)
        return prediction