In [None]:
from dataloader import TrajectoryDataset
from torch_geometric.loader import DataLoader
import torch

max_peds = 10

# NOTE(review): previously both datasets were loaded from './datasets/eth/train', so the
# per-epoch "val_loss" was computed on training data. The evaluation set now comes from a
# held-out split. Assumes the scene directory contains train/val/test sub-folders — TODO
# confirm './datasets/eth/val' exists (switch to '/test' for a final held-out evaluation).
data_root = './datasets/eth'
train_dataset = TrajectoryDataset(f'{data_root}/train', max_peds=max_peds)
test_dataset = TrajectoryDataset(f'{data_root}/val', max_peds=max_peds)

batch_size = 16

# Only the training data needs shuffling; evaluation order does not matter.
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, BatchNorm
import torch.nn as nn
import numpy as np

class GraphConvs(torch.nn.Module):
	"""Three stacked GCN layers plus a linear read-out, applied to one timestep's graph.

	Node features of width ``num_node_features`` pass through GCN widths
	256 -> 512 -> 256 (ReLU after every convolution). A final linear layer maps
	them to ``pred_len * 2`` features per node.

	The BatchNorm modules are constructed and registered, so they appear in
	``parameters()`` and ``state_dict()``, but ``forward`` does not apply them.
	"""

	def __init__(self, num_node_features, pred_len):
		super().__init__()
		widths = (256, 512, 256)
		# Keep the registration order (conv1, bn1, conv2, bn2, conv3, bn3, linear_conv)
		# so that parameter ordering and random initialisation are unchanged.
		self.conv1 = GCNConv(num_node_features, widths[0], improved=True)
		self.bn1 = BatchNorm(widths[0])
		self.conv2 = GCNConv(widths[0], widths[1], improved=True)
		self.bn2 = BatchNorm(widths[1])
		self.conv3 = GCNConv(widths[1], widths[2], improved=True)
		self.bn3 = BatchNorm(widths[2])
		self.linear_conv = nn.Linear(widths[2], pred_len * 2)

	def forward(self, graph):
		"""Return per-node features of width ``pred_len * 2`` for ``graph``.

		``graph`` must expose ``x`` (node feature matrix) and ``edge_index``.
		"""
		hidden, edges = graph.x, graph.edge_index
		for conv in (self.conv1, self.conv2, self.conv3):
			hidden = F.relu(conv(hidden, edges))
		return self.linear_conv(hidden)

class GCN(torch.nn.Module):
	"""Trajectory predictor: per-timestep graph convolutions averaged over time, then an MLP head.

	Args:
		num_node_features: width of each node's input feature vector.
		max_peds: nodes per sample in every per-timestep graph; the forward pass
			assumes each graph holds exactly ``max_peds`` nodes per sample.
		obs_len: number of observed timesteps (graphs in the input sequence).
		pred_len: number of future timesteps to predict.
		batch_size: unused; kept so existing callers keep working.
	"""

	def __init__(self, num_node_features, max_peds, obs_len, pred_len, batch_size):
		super().__init__()

		self.max_peds = max_peds
		self.pred_len = pred_len
		self.obs_len = obs_len

		# One independent graph-convolution stack per observed timestep.
		self.graph_convs = nn.ModuleList([GraphConvs(num_node_features, pred_len) for _ in range(obs_len)])

		self.linear1 = nn.Linear(max_peds * 2 * pred_len, 64)
		self.linear2 = nn.Linear(64, 32)
		# Output width is derived from pred_len. It was hard-coded as max_peds * 24,
		# which only matched pred_len == 12 and broke the final reshape otherwise.
		self.linear3 = nn.Linear(32, max_peds * 2 * pred_len)

	def forward(self, sequence):
		"""Predict future positions from a sequence of ``obs_len`` graphs.

		Args:
			sequence: indexable of ``obs_len`` graphs, one per timestep, each fed to the
				matching per-timestep conv stack. Presumably PyG batches whose nodes
				are ordered sample-by-sample — TODO confirm against the loader.

		Returns:
			Tensor of shape (batch, max_peds, 2, pred_len).
		"""
		graphs_out = [self.graph_convs[t](sequence[t]) for t in range(self.obs_len)]
		# Stack along a new leading time axis: (obs_len, batch * max_peds, 2 * pred_len).
		# The old concat-then-reshape to (batch, obs_len, ...) misread this layout and
		# averaged rows from different samples together whenever batch > 1.
		x = torch.stack(graphs_out, dim=0)
		x = x.reshape(self.obs_len, -1, self.max_peds, 2, self.pred_len)
		# Average the per-timestep predictions -> (batch, max_peds, 2, pred_len).
		x = torch.mean(x, dim=0)
		x = torch.flatten(x, start_dim=1)
		x = F.relu(self.linear1(x))
		x = F.relu(self.linear2(x))
		# No activation on the regression output. A ReLU here clamped every predicted
		# coordinate to >= 0 and blocked gradients for outputs stuck at zero.
		x = self.linear3(x)
		return x.reshape(x.size(0), self.max_peds, 2, self.pred_len)

# Horizons for this model: 8 observed timesteps in, 12 predicted timesteps out.
obs_len = 8
pred_len = 12
model = GCN(num_node_features=2, max_peds=max_peds, obs_len=obs_len, pred_len=pred_len, batch_size=batch_size)

In [None]:
# Move the model to its device before building the optimizer, as the torch.optim
# documentation recommends, so the optimizer is created over the on-device parameters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = model.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

epochs = 500


def _batch_to_device(batch):
	"""Split a loader batch into (graphs, target) and move both onto ``device``.

	``graphs`` is returned as a new list with one entry per observed timestep.
	"""
	graphs, target = batch
	return [graph.to(device) for graph in graphs], target.to(device)


# Training
for epoch in range(epochs):
	# Training pass: set training mode once per epoch instead of once per batch.
	model.train()
	train_loss = 0.0
	train_num_batches = 0
	for batch in trainloader:
		graphs, target = _batch_to_device(batch)
		optimizer.zero_grad()
		output = model(graphs)
		# Loss is evaluated in float64, as before.
		loss = criterion(output.double(), target.double())
		loss.backward()
		optimizer.step()
		train_loss += loss.item()
		train_num_batches += 1

	# Evaluation pass: evaluation mode and no gradient tracking.
	model.eval()
	val_loss = 0.0
	val_num_batches = 0
	with torch.no_grad():
		for batch in testloader:
			graphs, target = _batch_to_device(batch)
			output = model(graphs)
			loss = criterion(output.double(), target.double())
			val_loss += loss.item()
			val_num_batches += 1

	print(f'epoch:{epoch}, running_loss:{train_loss/train_num_batches}')
	print(f'epoch:{epoch}, val_loss:{val_loss/val_num_batches}')
	print('----------------------------------------------------------')

In [None]:
import matplotlib.pyplot as plt

# Take the first (un-batched) sample of the evaluation dataset.
data, target = next(iter(test_dataset))
data = [graph.to(device) for graph in data]

# Inference only: evaluation mode, no autograd graph. Call the module rather than
# .forward() so registered hooks still run.
model.eval()
with torch.no_grad():
	pred = model(data)

# The model returns (1, max_peds, 2, pred_len). Apply the same squeeze to both
# tensors so that [ped][coord] indexing means the same thing for each. Before,
# pred[0][0] / pred[0][1] selected two different pedestrians instead of one
# pedestrian's two coordinate series.
target = target.to('cpu').squeeze()
pred = pred.to('cpu').squeeze()

print(target)
fig, ax = plt.subplots()
ax.plot(target[0][0], target[0][1], label='ground truth')
ax.plot(pred[0][0], pred[0][1], label='prediction')
ax.set(title='Pedestrian 0: predicted vs ground-truth trajectory')
ax.legend()
plt.show()