In [2]:
import os
import random

import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import *
from model import *

# Seed every RNG the pipeline may draw from so runs are repeatable.
# Python's stdlib `random` is seeded too: helpers pulled in through the
# star imports above may use it (not visible here — seeding it is harmless).
seed = 816
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # per PyTorch docs, this already seeds all devices
torch.cuda.manual_seed(seed)  # explicit CUDA seeding; redundant but harmless
torch.cuda.manual_seed_all(seed)

In [3]:
# ---------------------
# parameters
# ---------------------
lr = 2e-3
epochs = 100
batch_size = 1
pos_weights = 7  # positive-class weight, consumed by BCEWithLogitsLoss in the model cell
path = '../../../data/2023-graph-conflation/'

# ---------------------
# load data
# ---------------------
print('Load Datasets...')
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort them
# so the random_state-seeded splits below select the same files on every machine.
files = sorted(os.listdir(os.path.join(path, 'graphs', 'osm')))
# 80/20 (train+val)/test, then 80/20 train/val  ->  roughly 64/16/20 overall.
train_files, test_files = train_test_split(files, test_size=0.2, random_state=42)
train_files, val_files = train_test_split(train_files, test_size=0.2, random_state=42)

# make datasets (GraphDataset is not defined here; presumably from the utils/model star imports)
train_data = GraphDataset(path, train_files)
val_data = GraphDataset(path, val_files)
test_data = GraphDataset(path, test_files)

# data loader
# Reshuffle the training set every epoch, using a dedicated generator seeded from
# `seed` so the order is reproducible; val/test keep a fixed order.
train_dataloader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(seed),
)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

Load Datasets...


In [7]:
# ---------------------
#  models
# ---------------------
print('Load Model...')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both graphs get an identically configured GraphUNet. The positional args are
# forwarded unchanged; see GraphUNet's signature (not in this notebook) for their meaning.
unet_args = (2, 32, 64, 3)
model_osm = GraphUNet(*unet_args).to(device)
model_sdot = GraphUNet(*unet_args).to(device)

# A single optimizer steps the parameters of both models.
optimizer = torch.optim.Adam(
    list(model_osm.parameters()) + list(model_sdot.parameters()),
    lr=lr,
)

# pos_weight as a float tensor on the training device, instead of an int64 CPU
# tensor that only works through implicit dtype promotion / 0-dim scalar rules.
criterion = nn.BCEWithLogitsLoss(
    pos_weight=torch.tensor(pos_weights, dtype=torch.float32, device=device)
)
es = EarlyStopping(tolerance=10)

Load Model...


In [8]:
# ----------------
# Training
# ----------------
# Switch both networks to training mode (affects layers such as dropout/batch-norm, if present).
for _net in (model_osm, model_sdot):
    _net.train()

# NOTE(review): `epochs`, `es` (EarlyStopping) and `val_dataloader` are defined in
# earlier cells but not passed here; the saved progress bar totals 5672 iterations,
# which looks like one pass over the training loader — confirm whether Train loops
# over epochs / does validation internally.
Train(train_dataloader, model_osm, model_sdot, optimizer, criterion, device)

  4%|█▍                                      | 210/5672 [00:02<00:55, 97.94it/s]

KeyboardInterrupt

