In [8]:
import torch
from torch.utils.data import DataLoader
from datetime import datetime
from actor_50 import PtrNet1
from critic_50 import PtrNet2
from config_100n import Config
from data import Generator
from env_v4 import Env_tsp
import torch.nn as nn
from time import time

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim

class AE(nn.Module):
    """Small fully-connected autoencoder with a 64-wide bottleneck.

    Both the encoder and the decoder are two ``nn.Linear`` layers, each followed
    by ReLU (including the final reconstruction layer, so outputs are >= 0).

    Keyword Args:
        input_shape: size of the last dimension of the input features
            (2 for node coordinates in this notebook).
    """

    def __init__(self, **kwargs):
        super().__init__()
        n_in = kwargs["input_shape"]
        width = 64
        # Attribute names (and their registration order) define the state_dict keys
        # used by saved checkpoints, so they must stay exactly as they are.
        self.encoder_hidden_layer = nn.Linear(in_features=n_in, out_features=width)
        self.encoder_output_layer = nn.Linear(in_features=width, out_features=width)
        self.decoder_hidden_layer = nn.Linear(in_features=width, out_features=width)
        self.decoder_output_layer = nn.Linear(in_features=width, out_features=n_in)

    def encode(self, x):
        """Map features ``x`` (last dim = input_shape) to a non-negative 64-dim code."""
        hidden = torch.relu(self.encoder_hidden_layer(x))
        return torch.relu(self.encoder_output_layer(hidden))

    def decode(self, code):
        """Map a 64-dim code back to input_shape features (ReLU-clamped at 0)."""
        hidden = torch.relu(self.decoder_hidden_layer(code))
        return torch.relu(self.decoder_output_layer(hidden))

    def forward(self, features):
        """Reconstruct ``features`` by encoding then decoding them."""
        return self.decode(self.encode(features))

In [10]:
!jupyter kernelspec list



Available kernels:
  myenv              /home/students/s290510/.local/share/jupyter/kernels/myenv
  python3            /home/students/s290510/.local/share/jupyter/kernels/python3
  venv               /home/students/s290510/.local/share/jupyter/kernels/venv
  venv_deep1         /home/students/s290510/.local/share/jupyter/kernels/venv_deep1
  ir                 /opt/anaconda3/envs/bigdatalab_cpu_202101/share/jupyter/kernels/ir
  octave             /opt/anaconda3/envs/bigdatalab_cpu_202101/share/jupyter/kernels/octave
  graphframe_yarn    /usr/local/share/jupyter/kernels/graphframe_yarn
  pyspark_local      /usr/local/share/jupyter/kernels/pyspark_local
  pyspark_yarn       /usr/local/share/jupyter/kernels/pyspark_yarn


In [11]:
cfg = Config()
env = Env_tsp(cfg)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Pretrained autoencoder used to embed node coordinates before the actor/critic see them.
model = AE(input_shape=2).to(device)
# map_location lets a checkpoint saved on a GPU machine load on a CPU-only kernel (and vice versa).
model.load_state_dict(torch.load("model/vanilla_100n_64emb.pt", map_location=device))
model.eval()
# Freeze the encoder: neither optimizer in the training loop includes its parameters,
# so tracking gradients for them would only accumulate unused .grad tensors.
model.requires_grad_(False)

# Sanity check: a small batch of graphs from the environment.
batch = env.get_batch_nodes(10)
print(batch.shape)



cpu
torch.Size([10, 100, 2])


In [12]:
def generate_encoding(batch, encoder=None):
    """Embed every node of every graph in ``batch`` with the autoencoder's encoder.

    Args:
        batch: tensor whose last dimension is the encoder's input size, e.g.
            ``(n_graphs, n_nodes, 2)`` node coordinates.
        encoder: object exposing ``encode``. Defaults to the notebook-global
            ``model`` (the pretrained ``AE``).

    Returns:
        Tensor with the same leading dimensions as ``batch`` and the encoder's
        output size as the last dimension. It does not require grad.
    """
    if encoder is None:
        encoder = model
    # AE.encode is built from nn.Linear + relu, which act on the last dimension only.
    # The whole batch can therefore be encoded in one call, with no per-graph loop or
    # re-stacking; a custom ``encoder`` is assumed to broadcast the same way.
    # no_grad: the autoencoder is frozen, so an autograd graph through it would only
    # waste memory and accumulate unused gradients on its parameters.
    with torch.no_grad():
        return encoder.encode(batch)



In [13]:
#act_model = PtrNet1(cfg)
#act_model.load_state_dict(torch.load("model/100n_64emb_20ksteps_degenerated.pt"))


In [14]:
torch.backends.cudnn.benchmark = True

def train_model(cfg, env, log_path = None):
    try:    
        date = datetime.now().strftime('%m%d_%H_%M')

        torch.no_grad()
        act_model = PtrNet1(cfg)
        #act_model.load_state_dict(torch.load("model/100nodes-intermediate.pt"))


        if cfg.optim == 'Adam':
            act_optim = torch.optim.Adam(act_model.parameters(), lr = 0.001)
        if cfg.is_lr_decay:
            act_lr_scheduler = torch.optim.lr_scheduler.StepLR(act_optim, 
                            step_size=cfg.lr_decay_step, gamma=cfg.lr_decay)
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        act_model = act_model.to(device)

        if cfg.mode == 'train':
            cri_model = PtrNet2(cfg)
            if cfg.optim == 'Adam':
                cri_optim = torch.optim.Adam(cri_model.parameters(), lr = 0.001)
            if cfg.is_lr_decay:
                cri_lr_scheduler = torch.optim.lr_scheduler.StepLR(cri_optim, 
                            step_size = cfg.lr_decay_step, gamma = cfg.lr_decay)
            cri_model = cri_model.to(device)
            ave_cri_loss = 0.

        mse_loss = nn.MSELoss()
        dataset = Generator(cfg, env)
        dataloader = DataLoader(dataset, batch_size = cfg.batch, shuffle = True)

        ave_act_loss, ave_L = 0., 0.
        min_L, cnt = 1e7, 0
        t1 = time()
        # for i, inputs in tqdm(enumerate(dataloader)):
        #for i, inputs in enumerate(dataloader):
        for i, inputs in enumerate(dataloader):
            inputs = inputs.to(device)
            #now instead of the inputs (a batch of node coordinates), give extract the embedding from the nodes
            embed = generate_encoding(inputs)  
            pred_tour, ll = act_model(embed, device)
            real_l = env.stack_l_fast(inputs, pred_tour)

            if cfg.mode == 'train':
                embed = generate_encoding(inputs)  
                pred_l = cri_model(embed, device)
                cri_loss = mse_loss(pred_l, real_l.detach())
                cri_optim.zero_grad()
                cri_loss.backward(retain_graph=True)
                nn.utils.clip_grad_norm_(cri_model.parameters(), max_norm = 1., norm_type = 2)
                cri_optim.step()
                if cfg.is_lr_decay:
                    cri_lr_scheduler.step()

            adv = real_l.detach() - pred_l.detach()
            act_loss = (adv * ll).mean()
            act_optim.zero_grad()
            act_loss.backward( )
            nn.utils.clip_grad_norm_(act_model.parameters(), max_norm = 1., norm_type = 2)
            act_optim.step()
            if cfg.is_lr_decay:
                act_lr_scheduler.step()

            ave_act_loss += act_loss.item()
            if cfg.mode == 'train':
                ave_cri_loss += cri_loss.item()
            ave_L += real_l.mean().item()

            if i % 5 == 0:
                env.show(inputs[0], pred_tour[0])
                #print(pred_tour[''])

            if i % cfg.log_step == 0:
                t2 = time()
                if cfg.mode == 'train':	
                    print('step:%d/%d, actic loss:%1.3f, critic loss:%1.3f, L:%1.3f, %dmin%dsec'%(i, cfg.steps, ave_act_loss/(i+1), ave_cri_loss/(i+1), ave_L/(i+1), (t2-t1)//60, (t2-t1)%60))
                    if cfg.islogger:
                        if log_path is None:
                            log_path = cfg.log_dir + '%s_%s_train.csv'%(date, cfg.task)#cfg.log_dir = ./Csv/
                            with open(log_path, 'w') as f:
                                f.write('step,actic loss,critic loss,average distance,time\n')
                        else:
                            with open(log_path, 'a') as f:
                                f.write('%d,%1.4f,%1.4f,%1.4f,%dmin%dsec\n'%(i, ave_act_loss/(i+1), ave_cri_loss/(i+1), ave_L/(i+1), (t2-t1)//60, (t2-t1)%60))


                if(ave_L/(i+1) < min_L):
                    min_L = ave_L/(i+1)
                if(ave_L/(i+1) < 126):
                    torch.save(act_model.state_dict(), cfg.model_dir + '%s_%s_step%d_act.pt'%(cfg.task, date, i))#'cfg.model_dir = ./Pt/'

                # else:
                # 	cnt += 1
                # 	print(f'cnt: {cnt}/20')
                # 	if(cnt >= 20):
                # 		print('early stop, average cost cant decrease anymore')
                # 		if log_path is not None:
                # 			with open(log_path, 'a') as f:
                # 				f.write('\nearly stop')
                # 		break
                t1 = time()
        if cfg.issaver:		
            torch.save(act_model.state_dict(), cfg.model_dir + '%s_%s_step%d_act.pt'%(cfg.task, date, i))#'cfg.model_dir = ./Pt/'
            print('save model...')
    except KeyboardInterrupt:
        torch.save(act_model.state_dict(), cfg.model_dir + '%s_%s_step%d_act.pt'%(cfg.task, date, i))#'cfg.model_dir = ./Pt/'
        

In [15]:
train_model(cfg, env)

UnboundLocalError: local variable 'i' referenced before assignment