In [1]:
import torch
import random
from collections import Counter
from itertools import combinations
import torch
from torch.utils.data import TensorDataset, DataLoader, ConcatDataset

from model_utils import *
from base_utils import *
from base_funcs_selfplay import simEpisode_batchpool_softmax
#from test_batch_sim import *

from collections import deque
import torch.multiprocessing as mp
from tqdm import tqdm

import os
import sys
import gc
import pickle
import argparse

In [9]:
def train_model(device, model, criterion, loader, nep, optimizer):
    """Train ``model`` in place for ``nep`` epochs and return it.

    Args:
        device: Target device handed to ``.to(...)`` (e.g. ``'cuda'`` or ``'cpu'``).
        model: Module to optimise; it is moved to ``device`` and put in train mode.
        criterion: Callable ``criterion(outputs, targets)`` returning a scalar loss tensor.
        loader: Iterable yielding ``(inputs, targets)`` batches. It does not need
            to implement ``__len__`` and may be empty.
        nep: Number of passes over ``loader``.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called once per batch;
            it is expected to hold ``model``'s parameters.

    Returns:
        The same ``model`` object that was passed in.
    """
    model.to(device)
    model.train()  # training mode for the whole run
    for epoch in range(nep):
        running_loss = 0.0
        # Count batches as they are consumed instead of relying on len(loader):
        # that works for loaders without __len__ and avoids dividing by zero
        # when the loader yields nothing.
        n_batches = 0

        for inputs, targets in tqdm(loader):
            # Cast to float32 first, then move to the target device.
            inputs = inputs.to(torch.float32).to(device)
            targets = targets.to(torch.float32).to(device)

            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1

        if n_batches == 0:
            print(f"Epoch {epoch+1}/{nep}, no batches in loader; nothing trained")
            continue

        # Mean of the per-batch losses over this epoch.
        epoch_loss = running_loss / n_batches
        print(f"Epoch {epoch+1}/{nep}, Training Loss: {epoch_loss:.4f}")
    return model

In [10]:
# Names of the training shards found directly under train/, split by the tag
# contained in the file name and sorted lexicographically.
sllist = sorted(name for name in os.listdir('train') if 'SLM' in name)
qvlist = sorted(name for name in os.listdir('train') if 'QV' in name)

In [11]:
# Fresh SL network using the `_Trans` architecture (4 heads, 6 layers per the
# keyword arguments); the class comes from one of the star-imports above.
new_SLM = Network_Pcard_V2_1_Trans(
    22, 7,
    y=1, x=15,
    trans_heads=4, trans_layers=6,
    hiddensize=512,
)

In [12]:
# Train the SL network on the first three SLM shards, one epoch per shard.
# NOTE: the recorded run of this cell printed "Training Loss: nan" for the first
# shard and was interrupted during the second one.
#
# The loss and the optimizer are built once, before the loop: re-creating Adam
# for every shard would throw away its running moment estimates and step count
# each time a new file is loaded.
criterion_sl = torch.nn.MSELoss()
opt = torch.optim.Adam(new_SLM.parameters(), lr=0.0001, weight_decay=1e-6)
for i in range(3):
    # torch.load unpickles the file, so only load shards you produced yourself.
    # NOTE(review): on recent PyTorch releases torch.load defaults to
    # weights_only=True, which may reject a pickled dataset object; confirm
    # against the installed version.
    ds_sl = torch.load(f'train/{sllist[i]}')
    train_loader = DataLoader(ds_sl, batch_size=512, shuffle=True, num_workers=0, pin_memory=True)
    new_SLM = train_model('cuda', new_SLM, criterion_sl, train_loader, 1, opt)

100%|██████████| 1643/1643 [00:24<00:00, 67.84it/s]


Epoch 1/1, Training Loss: nan


 46%|████▌     | 755/1645 [00:11<00:12, 68.55it/s]


KeyboardInterrupt: 

In [34]:
# Fresh SL and QV networks (the `_BN` variants) for the training run that follows.
new_SLM = Network_Pcard_V2_1_BN(
    22, 7,
    y=1, x=15,
    lstmsize=256,
    hiddensize=512,
)
new_QV = Network_Qv_Universal_V1_1_BN(
    6,
    15,
    512,
)

In [43]:
# Alternate SL and QV training over the first three shard pairs, one epoch per
# shard. Shard i of sllist is paired with shard i of qvlist.
#
# The loss and both optimizers are built once, before the loop: re-creating Adam
# for every shard would throw away its running moment estimates and step count
# each time a new file is loaded.
criterion = torch.nn.MSELoss()
opt_sl = torch.optim.Adam(new_SLM.parameters(), lr=0.00001, weight_decay=1e-6)
opt_qv = torch.optim.Adam(new_QV.parameters(), lr=0.00001, weight_decay=1e-6)
for i in range(3):
    # torch.load unpickles the file, so only load shards you produced yourself.
    # NOTE(review): on recent PyTorch releases torch.load defaults to
    # weights_only=True, which may reject a pickled dataset object; confirm
    # against the installed version.
    ds_sl = torch.load(f'train/{sllist[i]}')
    train_loader = DataLoader(ds_sl, batch_size=512, shuffle=True, num_workers=0, pin_memory=True)
    new_SLM = train_model('cuda', new_SLM, criterion, train_loader, 1, opt_sl)

    ds_qv = torch.load(f'train/{qvlist[i]}')
    train_loader = DataLoader(ds_qv, batch_size=512, shuffle=True, num_workers=0, pin_memory=True)
    new_QV = train_model('cuda', new_QV, criterion, train_loader, 1, opt_qv)

100%|██████████| 1643/1643 [00:11<00:00, 146.57it/s]


Epoch 1/1, Training Loss: 0.3287


100%|██████████| 1643/1643 [00:06<00:00, 252.84it/s]


Epoch 1/1, Training Loss: 0.1460


100%|██████████| 1645/1645 [00:10<00:00, 155.62it/s]


Epoch 1/1, Training Loss: 0.3273


100%|██████████| 1645/1645 [00:06<00:00, 238.39it/s]


Epoch 1/1, Training Loss: 0.1439


100%|██████████| 1646/1646 [00:11<00:00, 149.04it/s]


Epoch 1/1, Training Loss: 0.3268


100%|██████████| 1646/1646 [00:07<00:00, 234.47it/s]

Epoch 1/1, Training Loss: 0.1446





In [44]:
# Persist the trained weights as state dicts (parameters/buffers only, not the
# module objects). Create the output directory first so that a missing models/
# folder cannot make the save fail after training has already finished.
os.makedirs('models', exist_ok=True)
torch.save(new_SLM.state_dict(), os.path.join('models', f'SLM_H15-V2_2.3_{str(220025001).zfill(10)}.pt'))
torch.save(new_QV.state_dict(), os.path.join('models', f'QV_H15-V2_2.3_{str(220025001).zfill(10)}.pt'))

In [39]:
# The checkpoint at this path is written in this notebook with
# torch.save(new_SLM.state_dict(), ...), i.e. it is a plain state dict.
# torch.jit.load only reads TorchScript archives produced by torch.jit.save, so
# it cannot open that file. Rebuild the architecture with the same arguments as
# the Network_Pcard_V2_1_BN model created earlier in this notebook and restore
# the weights into it instead.
# NOTE(review): the result is a regular nn.Module, not a TorchScript module. If a
# scripted model is really required, save it with torch.jit.save (or script this
# instance afterwards); confirm the intent.
compiled_new_SLM = Network_Pcard_V2_1_BN(22, 7, y=1, x=15, lstmsize=256, hiddensize=512)
compiled_new_SLM.load_state_dict(
    torch.load(
        os.path.join('models', f'SLM_H15-V2_2.3_{str(220025001).zfill(10)}.pt'),
        map_location='cpu',  # the weights were saved from a CUDA model; load without needing a GPU
    )
)

In [30]:
# Fresh SL and QV networks (the variants without the `_BN` suffix), both
# constructed with 256 as the hidden-size argument.
new_SLM = Network_Pcard_V2_1(
    22, 7,
    y=1, x=15,
    lstmsize=256,
    hiddensize=256,
)
new_QV = Network_Qv_Universal_V1_1(
    6,
    15,
    256,
)

In [None]:
# Alternate SL and QV training over the first twelve shard pairs, one epoch per
# shard. Shard i of sllist is paired with shard i of qvlist.
#
# The loss and both optimizers are built once, before the loop: re-creating Adam
# for every shard would throw away its running moment estimates and step count
# each time a new file is loaded.
criterion = torch.nn.MSELoss()
opt_sl = torch.optim.Adam(new_SLM.parameters(), lr=0.00001, weight_decay=1e-6)
opt_qv = torch.optim.Adam(new_QV.parameters(), lr=0.00001, weight_decay=1e-6)
for i in range(12):
    # torch.load unpickles the file, so only load shards you produced yourself.
    # NOTE(review): on recent PyTorch releases torch.load defaults to
    # weights_only=True, which may reject a pickled dataset object; confirm
    # against the installed version.
    ds_sl = torch.load(f'train/{sllist[i]}')
    train_loader = DataLoader(ds_sl, batch_size=64, shuffle=True, num_workers=0, pin_memory=True)
    new_SLM = train_model('cuda', new_SLM, criterion, train_loader, 1, opt_sl)

    ds_qv = torch.load(f'train/{qvlist[i]}')
    train_loader = DataLoader(ds_qv, batch_size=64, shuffle=True, num_workers=0, pin_memory=True)
    new_QV = train_model('cuda', new_QV, criterion, train_loader, 1, opt_qv)