-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
41 lines (35 loc) · 1.32 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from library_imports import *
from utils import *
from train import *
# Fix the seeds of Python's and torch's global RNGs so runs are reproducible.
random.seed(1)
torch.manual_seed(1)

# NOTE(review): `device` is never used in this script and is not passed to
# train(); presumably train() selects its own device — confirm before removing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('CUDA availability:', torch.cuda.is_available())

# One TensorBoard run directory per launch, named by the local start time.
writer = SummaryWriter("./log/" + datetime.now().strftime("%Y%m%d-%H%M%S"))

# ------------------------------------------------------------------------------
# Specifying the dataset.
# Alternatives previously listed here: 'dataset_1', 'dataset_2'
# (presumably the names load_data() accepts — confirm in utils).
dataset_name = 'dataset_2'

# ----------------------------------------------------
# Load data
dataset = load_data(dataset_name)

# ----------------------------------------------------
# Preprocess data
dataset = data_preprocessing(dataset)

# ----------------------------------------------------
# Shuffle in place with a dedicated, fixed-seed RNG so the order does not
# depend on how many draws the global `random` state has already served.
# NOTE(review): requires `dataset` to be a mutable sequence (e.g. a list) —
# confirm data_preprocessing() returns one.
random.Random(1).shuffle(dataset)
print('Pytorch Geometric dataset has been shuffeled.')

# ----- Final Run ------
# Alternatives previously listed here: 'config1' through 'config12'
# (presumably the configs train() understands — confirm in train.py).
config_selected = 'config3'

try:
    model = train(dataset, writer, dataset_name, config_selected)
finally:
    # SummaryWriter buffers events in memory; close() flushes them to disk so
    # the final logged values are not lost on exit or if train() raises.
    writer.close()