In [None]:
import torch
from torch.optim import Adam
from tqdm.auto import trange
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import trange, tqdm
from cloudmanufacturing.data import read_fatahi_dataset
from cloudmanufacturing.mip_solver import mip_solve
from cloudmanufacturing.validation import objvalue, construct_delta
from cloudmanufacturing.graph import dglgraph_fixed, graph_gamma, os_type, ss_type, so_type
from cloudmanufacturing.graphconv import GNN
import dgl
from dgl.dataloading import GraphDataLoader
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import precision_recall_curve, auc
import pickle
import torch.nn.functional as F
import logging

## Загрузка данных и оптимальных решений

In [None]:
# Read the pickled worksheet names, then load those sheets from the training
# workbook (the file that also stores the optimal solutions).
# NOTE(review): pickle.load can execute arbitrary code — only use trusted files.
with open('../data/train_data_sheet_names.pickle' , 'rb') as sheets_file:
    sheet_names = pickle.load(sheets_file)

dataset = read_fatahi_dataset(
    '../data/train_data_OPTIMAL.xlsx',
    sheet_names,
)

In [None]:
# Load the pickled DGL graphs (one per training problem) into a list.
# NOTE(review): pickle.load can execute arbitrary code — only use trusted files.
with open('../data/train_data_solve.pickle' , 'rb') as graphs_file:
    DGList = pickle.load(graphs_file)

## Настройка обучения
Пока нет полноценных тренировочных и тестовых данных, просто проверяем, что обучение запускается. Параметры `batch_s` и `n_epochs` — число задач в батче и число эпох обучения.

In [None]:
# Training hyper-parameters: problems (graphs) per batch and passes over the data.
batch_s, n_epochs = 2, 20
model_path = '../data/GNNmodel.pth'  # destination of the trained state_dict

In [None]:
# Iterate over the problem graphs in shuffled batches of `batch_s`;
# a trailing incomplete batch is discarded (drop_last=True).
loader = GraphDataLoader(DGList, batch_size=batch_s, shuffle=True, drop_last=True)

In [None]:
# File logger for the training run; records are appended to ../data/GNN_run.log.
# NOTE(review): the logger name is a path-like string; it is kept unchanged
# because it appears in every record through %(name)s.
logger = logging.getLogger("../data/GNN_run")
logger.setLevel(logging.INFO)

# create the logging file handler
fh = logging.FileHandler("../data/GNN_run.log")

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh.setFormatter(formatter)

# getLogger returns the same logger object for the same name, so re-running
# this cell would stack another handler and duplicate every log line.
# Detach and close any handler already writing to this file first.
for stale in [h for h in logger.handlers
              if isinstance(h, logging.FileHandler) and h.baseFilename == fh.baseFilename]:
    logger.removeHandler(stale)
    stale.close()

# add handler to logger object
logger.addHandler(fh)

In [None]:
# Train the GNN on the edge targets of type `os_type` and record the mean
# batch loss of every epoch in `train_loss`; the trained weights are saved to
# `model_path` only if all epochs finish.
num_batches = len(loader)
epoch_train_loss = 0
train_loss = []
train_objvalue = []  # reserved for per-epoch objective values (not computed yet)
# Passed to the model as `ino_dim`; presumably the maximum number of
# operations per problem — TODO confirm.
oper_max = 20
model = GNN(ins_dim=1, ino_dim=oper_max, out_dim=16, n_layers=1)
optim = Adam(model.parameters(), lr=0.01)

try:
    logger.info('Start training')
    if num_batches == 0:
        # With drop_last=True, fewer graphs than batch_s yields an empty loader,
        # which would otherwise surface as a ZeroDivisionError below.
        raise ValueError(f'loader yields no batches (batch_size={batch_s}, drop_last=True)')
    model.train()
    for i in range(n_epochs):
        epoch_train_loss = 0  # reset the accumulator at the start of every epoch
        progress = tqdm(loader, desc=f'epoch {i}')
        for example_graph in progress:
            logits = model(example_graph)
            # Targets of the `os_type` edges; assumed to have the same shape as `logits`.
            example_target = example_graph.edata['target'][os_type]
            # reduction='mean' (the default) already averages over every edge in
            # the batch, so the value is not divided by the batch size again.
            loss = F.binary_cross_entropy_with_logits(logits, example_target)
            optim.zero_grad()
            loss.backward()
            optim.step()

            batch_loss = loss.item()
            epoch_train_loss += batch_loss
            progress.set_postfix(loss=batch_loss)
            logger.info(f"{i} epoch batch loss is {batch_loss}")

        train_loss.append(epoch_train_loss / num_batches)
        logger.info(f"{i} epoch mean loss is {train_loss[-1]}")

    torch.save(model.state_dict(), model_path)
except Exception:
    # Record the traceback in the run log, then re-raise so the failure is
    # visible in the notebook instead of being silently swallowed.
    logger.exception('GNN training was not successful')
    raise