# Critic pre-training (Critic 预训练)

In [1]:
import os

import torch
import torch.optim as optim
import torch.nn.functional as F

from PFSP import pfspStep
from PFSPNet import default_config, PFSPNet, PFSPDataLoader

from torch.utils.tensorboard import SummaryWriter

from datetime import datetime

# Config

In [9]:
# Problem size. `n` is the number of selectable rows per instance (it is the
# second dimension of the P tensor in the training loop); `m` is presumably
# the number of machines -- TODO confirm against PFSPDataLoader.
# `n_step` is how many rollout steps are taken per instance.
n, m, n_step = 10, 20, 5

# Training schedule and data volume.
n_epoch, dataset_size, batch_size = 1000, 1000, 1000

# Optimizer selection ('adam' or 'sgd') and its learning rate.
opt, lr = 'adam', 1e-5

# Size of the validation set.
validset_size = 500

# Model parameters.
# NOTE: this mutates the `default_config` dict imported from PFSPNet in place,
# so every later use of `default_config` in this session sees n_layers == 3.
default_config['n_layers'] = 3

# Train

In [3]:
# Build the network from the imported `default_config` dict.
model = PFSPNet(default_config)

In [10]:
# Data loaders for critic pre-training (PRETrainCritic=True).
# The validation loader is given validset_size for both size arguments,
# presumably so the whole validation set arrives as one batch -- TODO confirm
# against PFSPDataLoader's signature.
trainLoader = PFSPDataLoader(dataset_size, batch_size, n, m, PRETrainCritic=True)
validLoader = PFSPDataLoader(validset_size, validset_size, n, m, PRETrainCritic=True)

# Move the model to the GPU *before* constructing the optimizer, so the
# optimizer is built over the parameters in the location they will be
# trained in (as recommended by the torch.optim documentation).
if torch.cuda.is_available():
    model = model.cuda()

if opt == 'adam':
    optimizer = optim.Adam(model.parameters(), lr=lr)
elif opt == 'sgd':
    optimizer = optim.SGD(model.parameters(), lr=lr)
else:
    # Fail here with a clear message instead of a NameError on `optimizer`
    # much later inside the training loop.
    raise ValueError(f"unsupported optimizer {opt!r}; expected 'adam' or 'sgd'")

In [11]:
# Identify this run by its problem configuration and its start time.
CONFIGSTAMP = "n{}to{}_m{}".format(n, n - n_step, m)
TIMESTAMP = datetime.now().strftime('%Y-%m-%dT%H-%M-%S')

# TensorBoard events go to tb_logs/pretrain_critic/<CONFIGSTAMP>/<TIMESTAMP>/
writer = SummaryWriter(log_dir=f"tb_logs/pretrain_critic/{CONFIGSTAMP}/{TIMESTAMP}/")

In [12]:
# Global step used as the x-axis of the TensorBoard "loss" curves;
# the training loop advances it once per training batch.
k = 0

In [13]:
def _critic_rollout_loss(model, P, state, label, n_step):
    """Roll the policy out for ``n_step`` steps and score the critic.

    At every step the model is evaluated on the remaining rows of ``P`` and
    the current ``state``. One row per instance is sampled from the returned
    probabilities, passed to ``pfspStep`` to advance ``state``, and removed
    from ``P``. The critic output of every step is regressed onto ``label``.

    Args:
        model: callable returning ``(probs, baseline)`` for ``(P, state)``.
        P: tensor indexed as ``[batch, row, feature]``.
        state: state object handed to ``model`` and ``pfspStep``.
        label: regression target with one value per instance.
        n_step: number of rollout steps to perform.

    Returns:
        Scalar tensor: squared error summed over all instances and steps,
        divided by the batch size.
    """
    batch = P.shape[0]
    rows = torch.arange(batch, device=P.device)
    baselines = []

    for _ in range(n_step):
        n_left = P.shape[1]

        probs, baseline = model(P, state)
        # probs: [batch, n_left], baseline: [batch] (per the original notes)
        baselines.append(baseline)

        # Sample the next row at random according to the policy distribution.
        # The sample is drawn from detached probabilities, so no gradient
        # flows through the sampling itself.
        idx = torch.multinomial(probs.detach(), num_samples=1).view(-1)

        # Advance the state with the selected row of every instance.
        state = pfspStep(P[rows, idx, :], state)

        # Remove the selected row from every instance: [batch, n_left - 1, d].
        keep = F.one_hot(idx, num_classes=n_left) == 0
        P = P[keep].view(batch, n_left - 1, P.shape[-1])

    baselines = torch.stack(baselines, dim=1)
    # baselines: [batch, n_step]

    return ((label.view(-1, 1) - baselines) ** 2).sum() / batch


for epoch in range(n_epoch):

    model.train()

    for P, state, label in trainLoader:
        # The helper keeps its own locals, so the config globals `n` and
        # `batch_size` are no longer overwritten by the loop.
        loss = _critic_rollout_loss(model, P, state, label, n_step)

        # Train the critic.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        k += 1
        writer.add_scalars("loss", {"train": loss.item()}, k)

    # Validate on every 10th epoch (including epoch 0).
    if epoch % 10 == 0:
        model.eval()
        valid_sq_err = 0.0
        valid_seen = 0
        with torch.no_grad():
            for P, state, label in validLoader:
                valid_batch = P.shape[0]
                loss = _critic_rollout_loss(model, P, state, label, n_step)
                # Undo the per-batch normalisation so batches of different
                # sizes are weighted correctly in the overall average.
                valid_sq_err += loss.item() * valid_batch
                valid_seen += valid_batch

        # Log a single validation point per evaluation; logging every batch
        # at the same global step would stack several points on one x value.
        if valid_seen:
            writer.add_scalars("loss", {"valid": valid_sq_err / valid_seen}, k)

In [8]:
# Close the TensorBoard writer created for this run.
writer.close()

In [18]:
# Save the pre-trained weights (state_dict only) to
# models/pretrain_critic/<CONFIGSTAMP>/<TIMESTAMP>.pt
path = "models/pretrain_critic/" + CONFIGSTAMP

# exist_ok=True makes directory creation idempotent and avoids the
# check-then-create race of os.path.exists() followed by os.makedirs().
os.makedirs(path, exist_ok=True)

torch.save(model.state_dict(), path + "/" + TIMESTAMP + ".pt")