In [41]:
import torch
import random
from collections import Counter
from itertools import combinations
import torch
from torch.utils.data import TensorDataset, DataLoader, ConcatDataset

from model_utils import *
from base_utils import *

from collections import deque
import torch.multiprocessing as mp
from tqdm import tqdm

import os
import sys
import gc
import pickle
import argparse

In [42]:
def train_model(device, model, criterion, loader, nep, optimizer):
    #scaler = torch.cuda.amp.GradScaler()
    model.to(device)
    model.train()  # Set the model to training mode
    for epoch in range(nep):
        running_loss = 0.0
        
        for inputs, targets in tqdm(loader):
            inputs, targets = inputs.to(torch.float32).to(device), targets.to(torch.float32).to(device)
            
            optimizer.zero_grad()  # Zero the gradients
            outputs = model(inputs)  # Forward pass
            loss = criterion(outputs, targets)  # Calculate loss
            loss.backward()  # Backward pass
            optimizer.step()  # Optimize
            
            running_loss += loss.item()
        
        # Calculate the average loss per batch over the epoch
        epoch_loss = running_loss / len(loader)
        print(f"Epoch {epoch+1}/{nep}, Training Loss: {epoch_loss:.4f}")
    return model

def train_model_onehead(device, model, criterion, loader, nep, optimizer):
    #scaler = torch.cuda.amp.GradScaler()
    model.to(device)
    model.train()  # Set the model to training mode
    for epoch in range(nep):
        running_loss = 0.0
        
        for inputs, targets in tqdm(loader):
            if inputs.size(0) > 1:
                inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)
                
                outputs = model(inputs)[0]  # Forward pass
                loss = criterion(outputs, targets)  # Calculate loss
                loss.backward()  # Backward pass
                optimizer.step()  # Optimize
                optimizer.zero_grad()  # Zero the gradients
                
                running_loss += loss.item()
        
        # Calculate the average loss per batch over the epoch
        epoch_loss = running_loss / len(loader)
        print(f"Epoch {epoch+1}/{nep}, Training Loss: {epoch_loss:.4f}")
    return model

In [43]:
sXlist = [f for f in os.listdir('train') if 'SL_X' in f]
sXlist.sort()
sYlist = [f for f in os.listdir('train') if 'SL_Y' in f]
sYlist.sort()
qXlist = [f for f in os.listdir('train') if 'QV_X' in f]
qXlist.sort()
qYlist = [f for f in os.listdir('train') if 'QV_Y' in f]
qYlist.sort()
len(sXlist), len(sYlist), len(qXlist), len(qYlist)

(560, 560, 560, 560)

In [60]:
n_history, n_feature = 15, 7
new_SLM = Network_Pcard_V2_2_BN_dropout(n_history+n_feature, n_feature, y=1, x=15, lstmsize=512, hiddensize=1024, dropout_rate=0.1)

In [61]:
new_SLM.load_state_dict(torch.load(os.path.join('test_models',f'SLM_H15-V2_3.0L_{str(0).zfill(10)}.pt')))

<All keys matched successfully>

In [39]:
new_SLM.fc2.weight.data[0].mean().dtype

torch.float32

In [5]:
(n_feature+1+2+1)*15 + 512

677

In [None]:
#[for transfer learning. 1: gather SL_X, SL_Y, QV_Y. 2: train SLM. 3: SL_X+SL_y_hat = QV_x_hat. 4: train QV]

In [63]:
for i in range(100,200):
    ds_sl = TensorDataset(torch.load(f'train/{sXlist[i]}').to(torch.float32),torch.load(f'train/{sYlist[i]}').to(torch.float32))
    #train_loader = DataLoader(ds_sl, batch_size=2048, shuffle=True, num_workers=0, pin_memory=True)
    '''if i == 0:
        train_loader = DataLoader(ds_sl, batch_size=64, shuffle=True, num_workers=0, pin_memory=True)
        opt = torch.optim.Adam(new_SLM.parameters(), lr=0.00001, weight_decay=1e-8)
        new_SLM = train_model_onehead('cuda', new_SLM, torch.nn.MSELoss(), train_loader,5, opt)
    else:'''
    train_loader = DataLoader(ds_sl, batch_size=2048, shuffle=True, num_workers=0, pin_memory=True)
    opt = torch.optim.Adam(new_SLM.parameters(), lr=0.001, weight_decay=1e-8)
    new_SLM = train_model_onehead('cuda', new_SLM, torch.nn.MSELoss(), train_loader, 1, opt)
    #break

100%|██████████| 416/416 [00:08<00:00, 51.11it/s]


Epoch 1/1, Training Loss: 0.3482


100%|██████████| 420/420 [00:08<00:00, 51.23it/s]


Epoch 1/1, Training Loss: 0.3310


100%|██████████| 417/417 [00:08<00:00, 49.91it/s]


Epoch 1/1, Training Loss: 0.3257


100%|██████████| 419/419 [00:08<00:00, 50.94it/s]


Epoch 1/1, Training Loss: 0.3205


100%|██████████| 421/421 [00:08<00:00, 51.21it/s]


Epoch 1/1, Training Loss: 0.3173


100%|██████████| 416/416 [00:08<00:00, 50.00it/s]


Epoch 1/1, Training Loss: 0.3203


100%|██████████| 420/420 [00:08<00:00, 51.07it/s]


Epoch 1/1, Training Loss: 0.3163


100%|██████████| 422/422 [00:08<00:00, 50.69it/s]


Epoch 1/1, Training Loss: 0.3157


100%|██████████| 417/417 [00:07<00:00, 52.31it/s]


Epoch 1/1, Training Loss: 0.3162


100%|██████████| 416/416 [00:08<00:00, 50.97it/s]


Epoch 1/1, Training Loss: 0.3176


100%|██████████| 421/421 [00:08<00:00, 52.28it/s]


Epoch 1/1, Training Loss: 0.3134


100%|██████████| 419/419 [00:08<00:00, 51.22it/s]


Epoch 1/1, Training Loss: 0.3146


100%|██████████| 418/418 [00:08<00:00, 51.85it/s]


Epoch 1/1, Training Loss: 0.3131


100%|██████████| 416/416 [00:08<00:00, 51.03it/s]


Epoch 1/1, Training Loss: 0.3139


100%|██████████| 417/417 [00:08<00:00, 51.41it/s]


Epoch 1/1, Training Loss: 0.3137


100%|██████████| 418/418 [00:08<00:00, 51.51it/s]


Epoch 1/1, Training Loss: 0.3095


100%|██████████| 419/419 [00:08<00:00, 51.70it/s]


Epoch 1/1, Training Loss: 0.3111


100%|██████████| 418/418 [00:08<00:00, 52.07it/s]


Epoch 1/1, Training Loss: 0.3127


100%|██████████| 417/417 [00:07<00:00, 52.75it/s]


Epoch 1/1, Training Loss: 0.3096


100%|██████████| 419/419 [00:08<00:00, 50.84it/s]


Epoch 1/1, Training Loss: 0.3083


100%|██████████| 417/417 [00:07<00:00, 52.39it/s]


Epoch 1/1, Training Loss: 0.3062


100%|██████████| 420/420 [00:08<00:00, 51.76it/s]


Epoch 1/1, Training Loss: 0.3083


100%|██████████| 416/416 [00:08<00:00, 47.37it/s]


Epoch 1/1, Training Loss: 0.3074


100%|██████████| 417/417 [00:08<00:00, 51.39it/s]


Epoch 1/1, Training Loss: 0.3055


100%|██████████| 419/419 [00:07<00:00, 52.75it/s]


Epoch 1/1, Training Loss: 0.3056


100%|██████████| 417/417 [00:08<00:00, 51.06it/s]


Epoch 1/1, Training Loss: 0.3062


100%|██████████| 420/420 [00:07<00:00, 53.01it/s]


Epoch 1/1, Training Loss: 0.3051


100%|██████████| 418/418 [00:08<00:00, 51.66it/s]


Epoch 1/1, Training Loss: 0.3074


100%|██████████| 417/417 [00:08<00:00, 51.91it/s]


Epoch 1/1, Training Loss: 0.3079


100%|██████████| 419/419 [00:08<00:00, 51.95it/s]


Epoch 1/1, Training Loss: 0.3043


100%|██████████| 418/418 [00:07<00:00, 53.51it/s]


Epoch 1/1, Training Loss: 0.3065


100%|██████████| 419/419 [00:08<00:00, 52.23it/s]


Epoch 1/1, Training Loss: 0.3055


100%|██████████| 419/419 [00:07<00:00, 52.47it/s]


Epoch 1/1, Training Loss: 0.3050


100%|██████████| 421/421 [00:08<00:00, 51.49it/s]


Epoch 1/1, Training Loss: 0.3019


100%|██████████| 422/422 [00:08<00:00, 52.65it/s]


Epoch 1/1, Training Loss: 0.3004


100%|██████████| 418/418 [00:08<00:00, 51.51it/s]


Epoch 1/1, Training Loss: 0.3029


100%|██████████| 418/418 [00:07<00:00, 53.26it/s]


Epoch 1/1, Training Loss: 0.3037


100%|██████████| 417/417 [00:08<00:00, 52.05it/s]


Epoch 1/1, Training Loss: 0.3020


100%|██████████| 418/418 [00:07<00:00, 53.03it/s]


Epoch 1/1, Training Loss: 0.3040


100%|██████████| 419/419 [00:08<00:00, 51.32it/s]


Epoch 1/1, Training Loss: 0.3066


100%|██████████| 421/421 [00:07<00:00, 53.19it/s]


Epoch 1/1, Training Loss: 0.3047


100%|██████████| 419/419 [00:08<00:00, 51.87it/s]


Epoch 1/1, Training Loss: 0.3045


100%|██████████| 418/418 [00:07<00:00, 53.27it/s]


Epoch 1/1, Training Loss: 0.3018


100%|██████████| 418/418 [00:08<00:00, 51.60it/s]


Epoch 1/1, Training Loss: 0.3077


100%|██████████| 415/415 [00:07<00:00, 53.17it/s]


Epoch 1/1, Training Loss: 0.3048


100%|██████████| 420/420 [00:08<00:00, 51.77it/s]


Epoch 1/1, Training Loss: 0.3029


100%|██████████| 420/420 [00:07<00:00, 52.77it/s]


Epoch 1/1, Training Loss: 0.3013


100%|██████████| 419/419 [00:07<00:00, 52.54it/s]


Epoch 1/1, Training Loss: 0.3008


100%|██████████| 417/417 [00:07<00:00, 52.70it/s]


Epoch 1/1, Training Loss: 0.2995


100%|██████████| 418/418 [00:08<00:00, 51.85it/s]


Epoch 1/1, Training Loss: 0.3035


100%|██████████| 422/422 [00:08<00:00, 52.07it/s]


Epoch 1/1, Training Loss: 0.3017


100%|██████████| 419/419 [00:07<00:00, 52.49it/s]


Epoch 1/1, Training Loss: 0.3028


100%|██████████| 420/420 [00:08<00:00, 51.39it/s]


Epoch 1/1, Training Loss: 0.3019


100%|██████████| 416/416 [00:08<00:00, 51.68it/s]


Epoch 1/1, Training Loss: 0.3026


100%|██████████| 419/419 [00:08<00:00, 51.41it/s]


Epoch 1/1, Training Loss: 0.3029


100%|██████████| 419/419 [00:08<00:00, 51.30it/s]


Epoch 1/1, Training Loss: 0.3025


100%|██████████| 417/417 [00:08<00:00, 51.79it/s]


Epoch 1/1, Training Loss: 0.3024


100%|██████████| 417/417 [00:08<00:00, 50.86it/s]


Epoch 1/1, Training Loss: 0.3008


100%|██████████| 419/419 [00:08<00:00, 51.13it/s]


Epoch 1/1, Training Loss: 0.3022


100%|██████████| 417/417 [00:08<00:00, 50.21it/s]


Epoch 1/1, Training Loss: 0.3008


100%|██████████| 421/421 [00:08<00:00, 51.19it/s]


Epoch 1/1, Training Loss: 0.3014


100%|██████████| 417/417 [00:08<00:00, 50.07it/s]


Epoch 1/1, Training Loss: 0.3036


100%|██████████| 423/423 [00:08<00:00, 52.64it/s]


Epoch 1/1, Training Loss: 0.3007


100%|██████████| 421/421 [00:08<00:00, 50.93it/s]


Epoch 1/1, Training Loss: 0.3018


100%|██████████| 421/421 [00:07<00:00, 52.86it/s]


Epoch 1/1, Training Loss: 0.3009


100%|██████████| 416/416 [00:08<00:00, 50.59it/s]


Epoch 1/1, Training Loss: 0.3032


100%|██████████| 421/421 [00:07<00:00, 53.08it/s]


Epoch 1/1, Training Loss: 0.3012


100%|██████████| 421/421 [00:08<00:00, 52.24it/s]


Epoch 1/1, Training Loss: 0.3046


100%|██████████| 417/417 [00:08<00:00, 52.02it/s]


Epoch 1/1, Training Loss: 0.3007


100%|██████████| 420/420 [00:07<00:00, 52.89it/s]


Epoch 1/1, Training Loss: 0.3020


100%|██████████| 418/418 [00:07<00:00, 53.74it/s]


Epoch 1/1, Training Loss: 0.3003


100%|██████████| 422/422 [00:07<00:00, 52.98it/s]


Epoch 1/1, Training Loss: 0.2985


100%|██████████| 419/419 [00:07<00:00, 54.04it/s]


Epoch 1/1, Training Loss: 0.2987


100%|██████████| 420/420 [00:07<00:00, 53.46it/s]


Epoch 1/1, Training Loss: 0.2992


100%|██████████| 419/419 [00:07<00:00, 53.44it/s]


Epoch 1/1, Training Loss: 0.3021


100%|██████████| 419/419 [00:07<00:00, 53.62it/s]


Epoch 1/1, Training Loss: 0.3023


100%|██████████| 418/418 [00:07<00:00, 53.70it/s]


Epoch 1/1, Training Loss: 0.3011


100%|██████████| 419/419 [00:07<00:00, 53.20it/s]


Epoch 1/1, Training Loss: 0.2978


100%|██████████| 418/418 [00:07<00:00, 53.52it/s]


Epoch 1/1, Training Loss: 0.3019


100%|██████████| 416/416 [00:07<00:00, 53.37it/s]


Epoch 1/1, Training Loss: 0.2998


100%|██████████| 419/419 [00:07<00:00, 53.21it/s]


Epoch 1/1, Training Loss: 0.3001


100%|██████████| 416/416 [00:07<00:00, 53.11it/s]


Epoch 1/1, Training Loss: 0.3017


100%|██████████| 419/419 [00:07<00:00, 52.93it/s]


Epoch 1/1, Training Loss: 0.2998


100%|██████████| 419/419 [00:07<00:00, 52.70it/s]


Epoch 1/1, Training Loss: 0.2997


100%|██████████| 417/417 [00:07<00:00, 53.16it/s]


Epoch 1/1, Training Loss: 0.2985


100%|██████████| 418/418 [00:08<00:00, 51.96it/s]


Epoch 1/1, Training Loss: 0.2970


100%|██████████| 416/416 [00:07<00:00, 56.26it/s]


Epoch 1/1, Training Loss: 0.2994


100%|██████████| 418/418 [00:07<00:00, 56.27it/s]


Epoch 1/1, Training Loss: 0.2972


100%|██████████| 418/418 [00:07<00:00, 56.97it/s]


Epoch 1/1, Training Loss: 0.2980


100%|██████████| 420/420 [00:07<00:00, 56.02it/s]


Epoch 1/1, Training Loss: 0.2977


100%|██████████| 417/417 [00:07<00:00, 57.29it/s]


Epoch 1/1, Training Loss: 0.2995


100%|██████████| 419/419 [00:07<00:00, 56.02it/s]


Epoch 1/1, Training Loss: 0.3008


100%|██████████| 417/417 [00:07<00:00, 57.34it/s]


Epoch 1/1, Training Loss: 0.2991


100%|██████████| 419/419 [00:07<00:00, 55.55it/s]


Epoch 1/1, Training Loss: 0.2985


100%|██████████| 419/419 [00:07<00:00, 57.27it/s]


Epoch 1/1, Training Loss: 0.3006


100%|██████████| 416/416 [00:07<00:00, 55.92it/s]


Epoch 1/1, Training Loss: 0.3020


100%|██████████| 418/418 [00:07<00:00, 57.29it/s]


Epoch 1/1, Training Loss: 0.2988


100%|██████████| 420/420 [00:07<00:00, 55.97it/s]


Epoch 1/1, Training Loss: 0.2987


100%|██████████| 415/415 [00:07<00:00, 57.18it/s]


Epoch 1/1, Training Loss: 0.2981


100%|██████████| 418/418 [00:07<00:00, 56.16it/s]

Epoch 1/1, Training Loss: 0.3004





In [47]:
# test 
new_SLM_f16 = new_SLM.to(torch.float16).to('cuda')
new_SLM_f16.eval()

for i in range(200,210):
    data_X = torch.load(f'train/{sXlist[i]}').to(torch.float16).to('cuda')
    data_Y = torch.load(f'train/{sYlist[i]}').to(torch.float16).to('cuda')
    dataset = TensorDataset(data_X, data_Y)
    batch_size = 2048
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    # Initialize loss function
    mse_loss = torch.nn.MSELoss()

    total_loss = 0.0
    total_batches = 0

    # Iterate through DataLoader to process in batches
    with torch.inference_mode():
        for batch_X, batch_Y in dataloader:
            pred_Y = new_SLM_f16(batch_X)[0]
            loss = mse_loss(pred_Y, batch_Y)  # Calculate loss for the batch
            total_loss += loss.item()
            total_batches += 1

    # Calculate average loss
    average_loss = total_loss / total_batches
    print(average_loss)
    #break

0.3173761276971726
0.31394217354910714
0.3171545175405649
0.3181621396740752
0.3116596818535249
0.3157996858218974
0.31510155087425595
0.31621017087484904
0.3145986815246437
0.31731739683014354


In [64]:
torch.save(new_SLM.state_dict(),os.path.join('test_models',f'SLM_H15-V2_3.0L_{str(0).zfill(10)}.pt'))

In [48]:
torch.cuda.empty_cache()

In [55]:
new_QV = Network_Qv_Universal_V1_2_BN_dropout(input_size=(n_feature+1+2+1)*15,lstmsize=512, hsize=1024, dropout_rate=0.1) # action, lastmove, upper-lower state, action
print('Sample wt QV',new_QV.fc2.weight.data[0].mean().item())

Sample wt QV 8.034409984247759e-05


In [56]:
new_QV,print('Sample wt QV',new_QV.fc2.weight.data[0].mean().item()),new_QV.fc2.weight.data[0].cpu().mean()

Sample wt QV 8.034409984247759e-05


(Network_Qv_Universal_V1_2_BN_dropout(
   (dropout): Dropout(p=0.1, inplace=False)
   (fc1): Linear(in_features=677, out_features=1024, bias=True)
   (bn1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (fc2): Linear(in_features=1024, out_features=1024, bias=True)
   (bn2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (fc3): Linear(in_features=1024, out_features=1024, bias=True)
   (fc4): Linear(in_features=1024, out_features=1024, bias=True)
   (fc5): Linear(in_features=1024, out_features=1, bias=True)
   (flatten): Flatten(start_dim=1, end_dim=-1)
 ),
 None,
 tensor(8.0344e-05))

In [57]:
# generate QV_x and train QV
new_SLM_f16 = new_SLM.to(torch.float16).to('cuda')
new_SLM_f16.eval()


new_QV.to('cuda')
new_QV.train()
opt = torch.optim.Adam(new_QV.parameters(), lr=0.000001, weight_decay=1e-8)
criterion = nn.MSELoss()

for i in range(0,100):
    sl_X = torch.load(f'train/{sXlist[i]}').to(torch.float16).to('cuda')
    qv_Y = torch.load(f'train/{qYlist[i]}').to(torch.float32).to('cuda')
    qv_A = torch.load(f'train/{qXlist[i]}').to(torch.float16).to('cuda')

    dataset = TensorDataset(sl_X, qv_Y, qv_A)
    batch_size = 2048
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    running_loss = 0.0

    for batch_X, batch_Y, batch_A in tqdm(loader):

        with torch.inference_mode():
            model_inter, qv_X = new_SLM_f16(batch_X)
        #print(qv_X.shape,batch_X.shape,model_inter.shape, 8 * 15 + 2 * 15 + 512)
        qv_X = torch.concat([batch_X[:,0:8,0].view(batch_X.size(0), -1), # self
                                    model_inter, # upper and lower states
                                    #role,
                                    qv_X, # lstm encoded history
                                    batch_A
                                    ],dim=-1).to(torch.float32)
        #print(qv_X.shape)

        opt.zero_grad()  # Zero the gradients
        outputs = new_QV(qv_X)  # Forward pass
        loss = criterion(outputs, batch_Y)  # Calculate loss
        loss.backward()  # Backward pass
        opt.step()  # Optimize
        
        running_loss += loss.item()

    # Calculate the average loss per batch over the epoch
    epoch_loss = running_loss / len(loader)
    print(f"Epoch 0 Training Loss: {epoch_loss:.4f}")
    #print('Sample wt QV',new_QV.fc2.weight.data[0].mean().item())
    #break

100%|██████████| 422/422 [00:07<00:00, 59.33it/s]


Epoch 0 Training Loss: 0.2561


100%|██████████| 421/421 [00:07<00:00, 57.38it/s]


Epoch 0 Training Loss: 0.2496


100%|██████████| 421/421 [00:07<00:00, 60.12it/s]


Epoch 0 Training Loss: 0.2467


100%|██████████| 421/421 [00:06<00:00, 60.48it/s]


Epoch 0 Training Loss: 0.2470


100%|██████████| 419/419 [00:07<00:00, 59.72it/s]


Epoch 0 Training Loss: 0.2445


100%|██████████| 419/419 [00:07<00:00, 57.62it/s]


Epoch 0 Training Loss: 0.2422


100%|██████████| 418/418 [00:07<00:00, 56.78it/s]


Epoch 0 Training Loss: 0.2431


100%|██████████| 419/419 [00:07<00:00, 58.89it/s]


Epoch 0 Training Loss: 0.2420


100%|██████████| 418/418 [00:07<00:00, 58.39it/s]


Epoch 0 Training Loss: 0.2418


100%|██████████| 417/417 [00:07<00:00, 58.43it/s]


Epoch 0 Training Loss: 0.2395


100%|██████████| 421/421 [00:07<00:00, 57.41it/s]


Epoch 0 Training Loss: 0.2409


100%|██████████| 415/415 [00:07<00:00, 58.41it/s]


Epoch 0 Training Loss: 0.2405


100%|██████████| 418/418 [00:07<00:00, 57.70it/s]


Epoch 0 Training Loss: 0.2395


100%|██████████| 417/417 [00:07<00:00, 59.49it/s]


Epoch 0 Training Loss: 0.2403


100%|██████████| 416/416 [00:07<00:00, 58.57it/s]


Epoch 0 Training Loss: 0.2393


100%|██████████| 422/422 [00:07<00:00, 59.77it/s]


Epoch 0 Training Loss: 0.2392


100%|██████████| 417/417 [00:06<00:00, 59.95it/s]


Epoch 0 Training Loss: 0.2398


100%|██████████| 417/417 [00:07<00:00, 58.42it/s]


Epoch 0 Training Loss: 0.2407


100%|██████████| 418/418 [00:06<00:00, 60.20it/s]


Epoch 0 Training Loss: 0.2388


100%|██████████| 419/419 [00:06<00:00, 60.03it/s]


Epoch 0 Training Loss: 0.2376


100%|██████████| 421/421 [00:07<00:00, 56.93it/s]


Epoch 0 Training Loss: 0.2387


100%|██████████| 421/421 [00:07<00:00, 57.86it/s]


Epoch 0 Training Loss: 0.2381


100%|██████████| 420/420 [00:06<00:00, 60.81it/s]


Epoch 0 Training Loss: 0.2377


100%|██████████| 421/421 [00:06<00:00, 60.17it/s]


Epoch 0 Training Loss: 0.2392


100%|██████████| 416/416 [00:07<00:00, 58.55it/s]


Epoch 0 Training Loss: 0.2380


100%|██████████| 420/420 [00:07<00:00, 59.11it/s]


Epoch 0 Training Loss: 0.2383


100%|██████████| 421/421 [00:07<00:00, 58.29it/s]


Epoch 0 Training Loss: 0.2384


100%|██████████| 417/417 [00:06<00:00, 60.21it/s]


Epoch 0 Training Loss: 0.2382


100%|██████████| 418/418 [00:07<00:00, 59.11it/s]


Epoch 0 Training Loss: 0.2384


100%|██████████| 420/420 [00:06<00:00, 61.05it/s]


Epoch 0 Training Loss: 0.2379


100%|██████████| 420/420 [00:07<00:00, 59.46it/s]


Epoch 0 Training Loss: 0.2375


100%|██████████| 421/421 [00:06<00:00, 61.38it/s]


Epoch 0 Training Loss: 0.2383


100%|██████████| 417/417 [00:07<00:00, 56.48it/s]


Epoch 0 Training Loss: 0.2380


100%|██████████| 420/420 [00:06<00:00, 60.29it/s]


Epoch 0 Training Loss: 0.2382


100%|██████████| 417/417 [00:06<00:00, 61.66it/s]


Epoch 0 Training Loss: 0.2371


100%|██████████| 419/419 [00:07<00:00, 59.71it/s]


Epoch 0 Training Loss: 0.2375


100%|██████████| 417/417 [00:07<00:00, 59.47it/s]


Epoch 0 Training Loss: 0.2368


100%|██████████| 417/417 [00:06<00:00, 59.87it/s]


Epoch 0 Training Loss: 0.2387


100%|██████████| 414/414 [00:06<00:00, 59.26it/s]


Epoch 0 Training Loss: 0.2359


100%|██████████| 423/423 [00:06<00:00, 61.01it/s]


Epoch 0 Training Loss: 0.2371


100%|██████████| 418/418 [00:07<00:00, 59.63it/s]


Epoch 0 Training Loss: 0.2370


100%|██████████| 422/422 [00:07<00:00, 56.07it/s]


Epoch 0 Training Loss: 0.2367


100%|██████████| 419/419 [00:07<00:00, 59.16it/s]


Epoch 0 Training Loss: 0.2380


100%|██████████| 419/419 [00:06<00:00, 60.12it/s]


Epoch 0 Training Loss: 0.2374


100%|██████████| 418/418 [00:07<00:00, 59.63it/s]


Epoch 0 Training Loss: 0.2369


100%|██████████| 423/423 [00:07<00:00, 58.70it/s]


Epoch 0 Training Loss: 0.2387


100%|██████████| 416/416 [00:06<00:00, 60.61it/s]


Epoch 0 Training Loss: 0.2369


100%|██████████| 420/420 [00:07<00:00, 59.39it/s]


Epoch 0 Training Loss: 0.2377


100%|██████████| 418/418 [00:06<00:00, 60.12it/s]


Epoch 0 Training Loss: 0.2358


100%|██████████| 417/417 [00:07<00:00, 58.46it/s]


Epoch 0 Training Loss: 0.2369


100%|██████████| 418/418 [00:07<00:00, 59.11it/s]


Epoch 0 Training Loss: 0.2371


100%|██████████| 417/417 [00:06<00:00, 59.79it/s]


Epoch 0 Training Loss: 0.2371


100%|██████████| 418/418 [00:06<00:00, 60.18it/s]


Epoch 0 Training Loss: 0.2361


100%|██████████| 419/419 [00:07<00:00, 58.96it/s]


Epoch 0 Training Loss: 0.2371


100%|██████████| 420/420 [00:06<00:00, 60.62it/s]


Epoch 0 Training Loss: 0.2368


100%|██████████| 417/417 [00:07<00:00, 59.38it/s]


Epoch 0 Training Loss: 0.2360


100%|██████████| 419/419 [00:06<00:00, 60.83it/s]


Epoch 0 Training Loss: 0.2364


100%|██████████| 419/419 [00:07<00:00, 59.46it/s]


Epoch 0 Training Loss: 0.2360


100%|██████████| 419/419 [00:06<00:00, 60.51it/s]


Epoch 0 Training Loss: 0.2363


100%|██████████| 421/421 [00:07<00:00, 57.53it/s]


Epoch 0 Training Loss: 0.2358


100%|██████████| 420/420 [00:07<00:00, 57.15it/s]


Epoch 0 Training Loss: 0.2369


100%|██████████| 420/420 [00:07<00:00, 57.91it/s]


Epoch 0 Training Loss: 0.2370


100%|██████████| 421/421 [00:07<00:00, 58.78it/s]


Epoch 0 Training Loss: 0.2373


100%|██████████| 420/420 [00:07<00:00, 58.75it/s]


Epoch 0 Training Loss: 0.2356


100%|██████████| 420/420 [00:07<00:00, 58.56it/s]


Epoch 0 Training Loss: 0.2370


100%|██████████| 417/417 [00:06<00:00, 59.97it/s]


Epoch 0 Training Loss: 0.2368


100%|██████████| 418/418 [00:06<00:00, 60.80it/s]


Epoch 0 Training Loss: 0.2373


100%|██████████| 416/416 [00:06<00:00, 59.87it/s]


Epoch 0 Training Loss: 0.2354


100%|██████████| 421/421 [00:06<00:00, 60.99it/s]


Epoch 0 Training Loss: 0.2374


100%|██████████| 421/421 [00:07<00:00, 59.63it/s]


Epoch 0 Training Loss: 0.2356


100%|██████████| 418/418 [00:07<00:00, 59.49it/s]


Epoch 0 Training Loss: 0.2369


100%|██████████| 420/420 [00:07<00:00, 58.95it/s]


Epoch 0 Training Loss: 0.2364


100%|██████████| 419/419 [00:06<00:00, 60.87it/s]


Epoch 0 Training Loss: 0.2381


100%|██████████| 418/418 [00:07<00:00, 58.67it/s]


Epoch 0 Training Loss: 0.2370


100%|██████████| 418/418 [00:06<00:00, 60.11it/s]


Epoch 0 Training Loss: 0.2375


100%|██████████| 416/416 [00:07<00:00, 58.80it/s]


Epoch 0 Training Loss: 0.2354


100%|██████████| 419/419 [00:07<00:00, 59.21it/s]


Epoch 0 Training Loss: 0.2365


100%|██████████| 418/418 [00:07<00:00, 57.77it/s]


Epoch 0 Training Loss: 0.2368


100%|██████████| 420/420 [00:07<00:00, 58.55it/s]


Epoch 0 Training Loss: 0.2368


100%|██████████| 421/421 [00:07<00:00, 57.27it/s]


Epoch 0 Training Loss: 0.2365


100%|██████████| 419/419 [00:07<00:00, 56.24it/s]


Epoch 0 Training Loss: 0.2367


100%|██████████| 418/418 [00:07<00:00, 57.45it/s]


Epoch 0 Training Loss: 0.2368


100%|██████████| 422/422 [00:07<00:00, 58.63it/s]


Epoch 0 Training Loss: 0.2366


100%|██████████| 417/417 [00:07<00:00, 57.57it/s]


Epoch 0 Training Loss: 0.2364


100%|██████████| 421/421 [00:07<00:00, 58.07it/s]


Epoch 0 Training Loss: 0.2365


100%|██████████| 418/418 [00:07<00:00, 58.12it/s]


Epoch 0 Training Loss: 0.2363


100%|██████████| 421/421 [00:07<00:00, 58.74it/s]


Epoch 0 Training Loss: 0.2360


100%|██████████| 421/421 [00:07<00:00, 57.80it/s]


Epoch 0 Training Loss: 0.2365


100%|██████████| 418/418 [00:07<00:00, 58.77it/s]


Epoch 0 Training Loss: 0.2368


100%|██████████| 417/417 [00:07<00:00, 58.39it/s]


Epoch 0 Training Loss: 0.2361


100%|██████████| 417/417 [00:07<00:00, 59.12it/s]


Epoch 0 Training Loss: 0.2369


100%|██████████| 424/424 [00:07<00:00, 58.16it/s]


Epoch 0 Training Loss: 0.2360


100%|██████████| 431/431 [00:07<00:00, 58.67it/s]


Epoch 0 Training Loss: 0.2385


100%|██████████| 416/416 [00:07<00:00, 58.94it/s]


Epoch 0 Training Loss: 0.2355


100%|██████████| 423/423 [00:07<00:00, 59.11it/s]


Epoch 0 Training Loss: 0.2381


100%|██████████| 415/415 [00:07<00:00, 59.15it/s]


Epoch 0 Training Loss: 0.2369


100%|██████████| 418/418 [00:07<00:00, 58.33it/s]


Epoch 0 Training Loss: 0.2367


100%|██████████| 418/418 [00:07<00:00, 59.54it/s]


Epoch 0 Training Loss: 0.2364


100%|██████████| 417/417 [00:07<00:00, 58.70it/s]


Epoch 0 Training Loss: 0.2355


100%|██████████| 423/423 [00:07<00:00, 57.82it/s]

Epoch 0 Training Loss: 0.2363





In [58]:
torch.save(new_QV.state_dict(),os.path.join('test_models',f'QV_H15-V2_3.0L_{str(0).zfill(10)}.pt'))

In [10]:
sllist = [f for f in os.listdir('train') if 'SLM' in f]
sllist.sort()
qvlist = [f for f in os.listdir('train') if 'QV' in f]
qvlist.sort()

In [11]:
new_SLM = Network_Pcard_V2_1_Trans(22, 7, y=1, x=15, trans_heads=4, trans_layers=6, hiddensize=512)

In [12]:
for i in range(3):
    ds_sl = torch.load(f'train/{sllist[i]}')
    train_loader = DataLoader(ds_sl, batch_size=512, shuffle=True, num_workers=0, pin_memory=True)
    opt = torch.optim.Adam(new_SLM.parameters(), lr=0.0001, weight_decay=1e-6)
    new_SLM = train_model('cuda', new_SLM, torch.nn.MSELoss(), train_loader, 1, opt)
    #break

100%|██████████| 1643/1643 [00:24<00:00, 67.84it/s]


Epoch 1/1, Training Loss: nan


 46%|████▌     | 755/1645 [00:11<00:12, 68.55it/s]


KeyboardInterrupt: 

In [34]:
new_SLM = Network_Pcard_V2_1_BN(22, 7, y=1, x=15, lstmsize=256, hiddensize=512)
new_QV = Network_Qv_Universal_V1_1_BN(6,15,512)

In [43]:
for i in range(3):
    ds_sl = torch.load(f'train/{sllist[i]}')
    train_loader = DataLoader(ds_sl, batch_size=512, shuffle=True, num_workers=0, pin_memory=True)
    opt = torch.optim.Adam(new_SLM.parameters(), lr=0.00001, weight_decay=1e-6)
    new_SLM = train_model('cuda', new_SLM, torch.nn.MSELoss(), train_loader, 1, opt)

    ds_qv = torch.load(f'train/{qvlist[i]}')
    train_loader = DataLoader(ds_qv, batch_size=512, shuffle=True, num_workers=0, pin_memory=True)
    opt = torch.optim.Adam(new_QV.parameters(), lr=0.00001, weight_decay=1e-6)
    new_QV = train_model('cuda', new_QV, torch.nn.MSELoss(), train_loader, 1, opt)

    #break

100%|██████████| 1643/1643 [00:11<00:00, 146.57it/s]


Epoch 1/1, Training Loss: 0.3287


100%|██████████| 1643/1643 [00:06<00:00, 252.84it/s]


Epoch 1/1, Training Loss: 0.1460


100%|██████████| 1645/1645 [00:10<00:00, 155.62it/s]


Epoch 1/1, Training Loss: 0.3273


100%|██████████| 1645/1645 [00:06<00:00, 238.39it/s]


Epoch 1/1, Training Loss: 0.1439


100%|██████████| 1646/1646 [00:11<00:00, 149.04it/s]


Epoch 1/1, Training Loss: 0.3268


100%|██████████| 1646/1646 [00:07<00:00, 234.47it/s]

Epoch 1/1, Training Loss: 0.1446





In [44]:
torch.save(new_SLM.state_dict(),os.path.join('models',f'SLM_H15-V2_2.3_{str(220025001).zfill(10)}.pt'))
torch.save(new_QV.state_dict(),os.path.join('models',f'QV_H15-V2_2.3_{str(220025001).zfill(10)}.pt'))

In [39]:
compiled_new_SLM = torch.jit.load(os.path.join('models',f'SLM_H15-V2_2.3_{str(220025001).zfill(10)}.pt'))

In [30]:
new_SLM = Network_Pcard_V2_1(22, 7, y=1, x=15, lstmsize=256, hiddensize=256)
new_QV = Network_Qv_Universal_V1_1(6,15,256)

In [None]:
for i in range(12):
    ds_sl = torch.load(f'train/{sllist[i]}')
    train_loader = DataLoader(ds_sl, batch_size=64, shuffle=True, num_workers=0, pin_memory=True)
    opt = torch.optim.Adam(new_SLM.parameters(), lr=0.00001, weight_decay=1e-6)
    new_SLM = train_model('cuda', new_SLM, torch.nn.MSELoss(), train_loader, 1, opt)

    ds_qv = torch.load(f'train/{qvlist[i]}')
    train_loader = DataLoader(ds_qv, batch_size=64, shuffle=True, num_workers=0, pin_memory=True)
    opt = torch.optim.Adam(new_QV.parameters(), lr=0.00001, weight_decay=1e-6)
    new_QV = train_model('cuda', new_QV, torch.nn.MSELoss(), train_loader, 1, opt)

    #break