# Обучение GNN-модели на датасете графов

In [None]:
%load_ext autoreload
%autoreload 2

from src.data.hydrodataset import HydroDataset

import torch
# Autograd anomaly detection adds extra checks to every backward pass and is
# meant for debugging only (it noticeably slows training). Flip this flag to
# True when hunting NaNs/Infs produced during backward; it does not help with
# errors raised in the forward pass.
DEBUG_AUTOGRAD = False
torch.autograd.set_detect_anomaly(DEBUG_AUTOGRAD)

from src.models.gnnprocessor import GNNProcessor
from src.visualization.visualize import visualize_graph
from src.models.train_model import HydraulicsLoss

from torch_geometric.nn import summary
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx

import matplotlib.pyplot as plt

In [None]:
# Build the hydraulic graph dataset rooted at /tmp/hydro
# (HydroDataset is project code; presumably a torch_geometric Dataset — confirm).
DATA_ROOT = "/tmp/hydro"
dataset = HydroDataset(root=DATA_ROOT)

# The whole dataset is currently used for training; there is no hold-out split.
train_ds = dataset
n_graphs = len(train_ds)
print("Dataset length:", n_graphs)

In [None]:
# Draw the first graph of the dataset as a directed networkx graph.
# Fetch the sample once, so the plotted graph and its node colouring come from
# the same Data object and the sample is not loaded/transformed twice.
sample = dataset[0]
G = to_networkx(sample, to_undirected=False)
# Boolean per-node mask: True where node feature column 1 equals 0.
# NOTE(review): the physical meaning of that feature is not visible here.
node_mask = sample.x[..., 1] == 0
visualize_graph(G, color=node_mask);


In [None]:
# Mini-batch loader over the training graphs. Shuffle every epoch so the
# batches are not drawn in the same fixed order each time
# (the PyG DataLoader default is shuffle=False).
BATCH_SIZE = 32
loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Prefer the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# GNN processor sized from the dataset's edge-feature count.
model = GNNProcessor(
    out_channels=1,
    num_edge_features=dataset.num_edge_features,
    latent_dim=10,
    num_convs=20,
)
model.to(device)

print(model)

In [24]:
criterion = HydraulicsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

NUM_EPOCHS = 250

# Mean training loss of every epoch, appended by train(); plotted below.
losses = []


def train(max_grad_norm=None):
    """Run one training epoch and return the mean loss per graph.

    Uses the notebook-level ``model``, ``loader``, ``criterion``,
    ``optimizer`` and ``device``; the epoch loss is also appended to
    ``losses``.

    Args:
        max_grad_norm: if not None, gradients are clipped to this total norm
            (``torch.nn.utils.clip_grad_norm_``) before each optimizer step.
            An earlier experiment used 0.01. None disables clipping.

    Raises:
        ValueError: if the loader's dataset is empty.
    """
    num_samples = len(loader.dataset)
    if num_samples == 0:
        raise ValueError("training dataset is empty")

    model.train()

    total_loss = 0.0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        # The model returns a 3-tuple; the middle element is unused here.
        # NOTE(review): the last run failed with "TypeError: expected Tensor
        # as element 2 in argument 0, but got list". That message comes from
        # an op taking a list of tensors (likely torch.cat/torch.stack)
        # inside model(...) or criterion(...), which are not visible here.
        P, _, imbalance = model(data)
        loss = criterion(data, P, imbalance)
        # Weight by graph count so the epoch value is a per-graph mean
        # (assumes criterion returns a batch mean — TODO confirm).
        total_loss += loss.item() * data.num_graphs
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

    epoch_loss = total_loss / num_samples
    losses.append(epoch_loss)
    return epoch_loss


for epoch in range(NUM_EPOCHS):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')

TypeError: expected Tensor as element 2 in argument 0, but got list

In [None]:
# Learning curve: one point per epoch, taken from `losses`.
epochs = range(len(losses))
plt.plot(epochs, losses)
plt.gca().set(title='Кривая обучения', xlabel='Эпохи', ylabel='Функция потерь')
plt.show()