In [2]:
import torch
import random
from collections import Counter
from itertools import combinations
import torch
from torch.utils.data import TensorDataset, DataLoader, ConcatDataset

from model_utils import *
from base_utils import *

from collections import deque
import torch.multiprocessing as mp
from tqdm import tqdm

import os
import sys
import gc
import pickle
import argparse

In [12]:
def train_model(device, model, criterion, loader, nep, optimizer):
    #scaler = torch.cuda.amp.GradScaler()
    model.to(device)
    model.train()  # Set the model to training mode
    for epoch in range(nep):
        running_loss = 0.0
        
        for inputs, targets in tqdm(loader):
            inputs, targets = inputs.to(torch.float32).to(device), targets.to(torch.float32).to(device)
            
            optimizer.zero_grad()  # Zero the gradients
            outputs = model(inputs)  # Forward pass
            loss = criterion(outputs, targets)  # Calculate loss
            loss.backward()  # Backward pass
            optimizer.step()  # Optimize
            
            running_loss += loss.item()
        
        # Calculate the average loss per batch over the epoch
        epoch_loss = running_loss / len(loader)
        print(f"Epoch {epoch+1}/{nep}, Training Loss: {epoch_loss:.4f}")
    return model

def train_model_onehead(device, model, criterion, loader, nep, optimizer):
    #scaler = torch.cuda.amp.GradScaler()
    model.to(device)
    model.train()  # Set the model to training mode
    for epoch in range(nep):
        running_loss = 0.0
        
        for inputs, targets in tqdm(loader):
            if inputs.size(0) > 1:
                inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)
                
                outputs = model(inputs)[0]  # Forward pass
                loss = criterion(outputs, targets)  # Calculate loss
                loss.backward()  # Backward pass
                optimizer.step()  # Optimize
                optimizer.zero_grad()  # Zero the gradients
                
                running_loss += loss.item()
        
        # Calculate the average loss per batch over the epoch
        epoch_loss = running_loss / len(loader)
        print(f"Epoch {epoch+1}/{nep}, Training Loss: {epoch_loss:.4f}")
    return model

In [3]:
sXlist = [f for f in os.listdir('train') if 'SL_X' in f]
sXlist.sort()
sYlist = [f for f in os.listdir('train') if 'SL_Y' in f]
sYlist.sort()
qYlist = [f for f in os.listdir('train') if 'QV_Y' in f]
qYlist.sort()
len(sXlist), len(sYlist), len(qYlist)

(220, 220, 220)

In [31]:
n_history, n_feature = 15, 7
new_SLM = Network_Pcard_V2_2_BN_dropout(n_history+n_feature, n_feature, y=1, x=15, lstmsize=512, hiddensize=1024, dropout_rate=0.1)
new_QV = Network_Qv_Universal_V1_2_BN_dropout(input_size=(n_feature+1+2+1)*15,lstmsize=512, hsize=1024, dropout_rate=0.1) # action, lastmove, upper-lower state, action


In [52]:
(n_feature+1+2+1)*15 + 512

677

In [None]:
#[for transfer learning. 1: gather SL_X, SL_Y, QV_Y. 2: train SLM. 3: SL_X+SL_y_hat = QV_x_hat. 4: train QV]

In [22]:
for i in range(101,200):
    ds_sl = TensorDataset(torch.load(f'train/{sXlist[i]}').to(torch.float32),torch.load(f'train/{sYlist[i]}').to(torch.float32))
    #train_loader = DataLoader(ds_sl, batch_size=2048, shuffle=True, num_workers=0, pin_memory=True)
    if i == 0:
        train_loader = DataLoader(ds_sl, batch_size=64, shuffle=True, num_workers=0, pin_memory=True)
        opt = torch.optim.Adam(new_SLM.parameters(), lr=0.00001, weight_decay=1e-8)
        new_SLM = train_model_onehead('cuda', new_SLM, torch.nn.MSELoss(), train_loader,5, opt)
    else:
        train_loader = DataLoader(ds_sl, batch_size=2048, shuffle=True, num_workers=0, pin_memory=True)
        opt = torch.optim.Adam(new_SLM.parameters(), lr=0.0001, weight_decay=1e-8)
        new_SLM = train_model_onehead('cuda', new_SLM, torch.nn.MSELoss(), train_loader, 1, opt)
    #break

100%|██████████| 404/404 [00:07<00:00, 52.13it/s]


Epoch 1/1, Training Loss: 0.3047


100%|██████████| 406/406 [00:07<00:00, 53.52it/s]


Epoch 1/1, Training Loss: 0.3013


100%|██████████| 404/404 [00:07<00:00, 54.85it/s]


Epoch 1/1, Training Loss: 0.2984


100%|██████████| 403/403 [00:07<00:00, 54.31it/s]


Epoch 1/1, Training Loss: 0.3014


100%|██████████| 408/408 [00:07<00:00, 51.94it/s]


Epoch 1/1, Training Loss: 0.3022


100%|██████████| 408/408 [00:07<00:00, 52.97it/s]


Epoch 1/1, Training Loss: 0.3006


100%|██████████| 406/406 [00:07<00:00, 54.75it/s]


Epoch 1/1, Training Loss: 0.3044


100%|██████████| 406/406 [00:07<00:00, 55.59it/s]


Epoch 1/1, Training Loss: 0.3026


100%|██████████| 404/404 [00:07<00:00, 53.99it/s]


Epoch 1/1, Training Loss: 0.3028


100%|██████████| 404/404 [00:07<00:00, 54.70it/s]


Epoch 1/1, Training Loss: 0.3022


100%|██████████| 406/406 [00:07<00:00, 54.41it/s]


Epoch 1/1, Training Loss: 0.3029


100%|██████████| 405/405 [00:07<00:00, 54.58it/s]


Epoch 1/1, Training Loss: 0.3047


100%|██████████| 406/406 [00:07<00:00, 55.17it/s]


Epoch 1/1, Training Loss: 0.3022


100%|██████████| 404/404 [00:07<00:00, 54.48it/s]


Epoch 1/1, Training Loss: 0.3000


100%|██████████| 408/408 [00:07<00:00, 54.37it/s]


Epoch 1/1, Training Loss: 0.3036


100%|██████████| 407/407 [00:07<00:00, 54.11it/s]


Epoch 1/1, Training Loss: 0.3025


100%|██████████| 407/407 [00:07<00:00, 54.51it/s]


Epoch 1/1, Training Loss: 0.3018


100%|██████████| 407/407 [00:07<00:00, 54.56it/s]


Epoch 1/1, Training Loss: 0.3030


100%|██████████| 405/405 [00:07<00:00, 54.30it/s]


Epoch 1/1, Training Loss: 0.3016


100%|██████████| 404/404 [00:07<00:00, 54.32it/s]


Epoch 1/1, Training Loss: 0.3029


100%|██████████| 406/406 [00:07<00:00, 54.73it/s]


Epoch 1/1, Training Loss: 0.3041


100%|██████████| 405/405 [00:07<00:00, 54.68it/s]


Epoch 1/1, Training Loss: 0.3013


100%|██████████| 407/407 [00:07<00:00, 54.67it/s]


Epoch 1/1, Training Loss: 0.3015


100%|██████████| 407/407 [00:07<00:00, 54.95it/s]


Epoch 1/1, Training Loss: 0.3021


100%|██████████| 409/409 [00:07<00:00, 54.64it/s]


Epoch 1/1, Training Loss: 0.2994


100%|██████████| 406/406 [00:07<00:00, 54.71it/s]


Epoch 1/1, Training Loss: 0.3012


100%|██████████| 409/409 [00:07<00:00, 54.56it/s]


Epoch 1/1, Training Loss: 0.3040


100%|██████████| 407/407 [00:07<00:00, 54.62it/s]


Epoch 1/1, Training Loss: 0.3027


100%|██████████| 406/406 [00:07<00:00, 54.55it/s]


Epoch 1/1, Training Loss: 0.3015


100%|██████████| 406/406 [00:07<00:00, 54.52it/s]


Epoch 1/1, Training Loss: 0.3016


100%|██████████| 404/404 [00:07<00:00, 54.56it/s]


Epoch 1/1, Training Loss: 0.3016


100%|██████████| 405/405 [00:07<00:00, 54.45it/s]


Epoch 1/1, Training Loss: 0.3026


100%|██████████| 406/406 [00:07<00:00, 53.59it/s]


Epoch 1/1, Training Loss: 0.3033


100%|██████████| 403/403 [00:07<00:00, 52.30it/s]


Epoch 1/1, Training Loss: 0.3029


100%|██████████| 406/406 [00:07<00:00, 53.27it/s]


Epoch 1/1, Training Loss: 0.3028


100%|██████████| 407/407 [00:07<00:00, 51.59it/s]


Epoch 1/1, Training Loss: 0.3018


100%|██████████| 406/406 [00:08<00:00, 50.74it/s]


Epoch 1/1, Training Loss: 0.3015


100%|██████████| 409/409 [00:08<00:00, 50.77it/s]


Epoch 1/1, Training Loss: 0.2997


100%|██████████| 404/404 [00:07<00:00, 51.72it/s]


Epoch 1/1, Training Loss: 0.3004


100%|██████████| 406/406 [00:07<00:00, 52.47it/s]


Epoch 1/1, Training Loss: 0.3031


100%|██████████| 409/409 [00:07<00:00, 52.66it/s]


Epoch 1/1, Training Loss: 0.3026


100%|██████████| 409/409 [00:07<00:00, 52.55it/s]


Epoch 1/1, Training Loss: 0.3017


100%|██████████| 407/407 [00:07<00:00, 52.62it/s]


Epoch 1/1, Training Loss: 0.3004


100%|██████████| 407/407 [00:07<00:00, 52.78it/s]


Epoch 1/1, Training Loss: 0.3030


100%|██████████| 407/407 [00:07<00:00, 52.51it/s]


Epoch 1/1, Training Loss: 0.3027


100%|██████████| 408/408 [00:07<00:00, 52.99it/s]


Epoch 1/1, Training Loss: 0.3011


100%|██████████| 407/407 [00:07<00:00, 53.22it/s]


Epoch 1/1, Training Loss: 0.3018


100%|██████████| 408/408 [00:07<00:00, 53.06it/s]


Epoch 1/1, Training Loss: 0.3025


100%|██████████| 405/405 [00:07<00:00, 52.73it/s]


Epoch 1/1, Training Loss: 0.3056


100%|██████████| 405/405 [00:07<00:00, 52.80it/s]


Epoch 1/1, Training Loss: 0.3035


100%|██████████| 407/407 [00:07<00:00, 53.07it/s]


Epoch 1/1, Training Loss: 0.3013


100%|██████████| 406/406 [00:07<00:00, 53.08it/s]


Epoch 1/1, Training Loss: 0.3018


100%|██████████| 407/407 [00:07<00:00, 52.88it/s]


Epoch 1/1, Training Loss: 0.3002


100%|██████████| 403/403 [00:07<00:00, 52.71it/s]


Epoch 1/1, Training Loss: 0.3021


100%|██████████| 408/408 [00:07<00:00, 52.79it/s]


Epoch 1/1, Training Loss: 0.3005


100%|██████████| 404/404 [00:07<00:00, 52.36it/s]


Epoch 1/1, Training Loss: 0.3020


100%|██████████| 407/407 [00:07<00:00, 52.38it/s]


Epoch 1/1, Training Loss: 0.3027


100%|██████████| 403/403 [00:07<00:00, 51.98it/s]


Epoch 1/1, Training Loss: 0.3021


100%|██████████| 408/408 [00:07<00:00, 52.06it/s]


Epoch 1/1, Training Loss: 0.2996


100%|██████████| 411/411 [00:07<00:00, 52.41it/s]


Epoch 1/1, Training Loss: 0.3014


100%|██████████| 407/407 [00:07<00:00, 52.91it/s]


Epoch 1/1, Training Loss: 0.3001


100%|██████████| 408/408 [00:07<00:00, 53.04it/s]


Epoch 1/1, Training Loss: 0.3015


100%|██████████| 404/404 [00:07<00:00, 53.07it/s]


Epoch 1/1, Training Loss: 0.2986


100%|██████████| 405/405 [00:07<00:00, 53.13it/s]


Epoch 1/1, Training Loss: 0.2991


100%|██████████| 405/405 [00:07<00:00, 53.05it/s]


Epoch 1/1, Training Loss: 0.2991


100%|██████████| 408/408 [00:08<00:00, 49.87it/s]


Epoch 1/1, Training Loss: 0.2982


100%|██████████| 407/407 [00:07<00:00, 52.07it/s]


Epoch 1/1, Training Loss: 0.3002


100%|██████████| 407/407 [00:07<00:00, 52.30it/s]


Epoch 1/1, Training Loss: 0.3003


100%|██████████| 407/407 [00:07<00:00, 52.23it/s]


Epoch 1/1, Training Loss: 0.3027


100%|██████████| 405/405 [00:07<00:00, 52.96it/s]


Epoch 1/1, Training Loss: 0.3031


100%|██████████| 412/412 [00:07<00:00, 52.93it/s]


Epoch 1/1, Training Loss: 0.3010


100%|██████████| 406/406 [00:07<00:00, 53.26it/s]


Epoch 1/1, Training Loss: 0.3015


100%|██████████| 404/404 [00:07<00:00, 53.15it/s]


Epoch 1/1, Training Loss: 0.3029


100%|██████████| 406/406 [00:07<00:00, 52.64it/s]


Epoch 1/1, Training Loss: 0.3005


100%|██████████| 405/405 [00:07<00:00, 53.21it/s]


Epoch 1/1, Training Loss: 0.3013


100%|██████████| 404/404 [00:07<00:00, 52.81it/s]


Epoch 1/1, Training Loss: 0.2998


100%|██████████| 405/405 [00:07<00:00, 52.79it/s]


Epoch 1/1, Training Loss: 0.3011


100%|██████████| 407/407 [00:07<00:00, 52.74it/s]


Epoch 1/1, Training Loss: 0.2995


100%|██████████| 406/406 [00:07<00:00, 52.69it/s]


Epoch 1/1, Training Loss: 0.3005


100%|██████████| 404/404 [00:07<00:00, 53.37it/s]


Epoch 1/1, Training Loss: 0.3035


100%|██████████| 407/407 [00:07<00:00, 52.76it/s]


Epoch 1/1, Training Loss: 0.3012


100%|██████████| 402/402 [00:07<00:00, 52.51it/s]


Epoch 1/1, Training Loss: 0.3021


100%|██████████| 407/407 [00:07<00:00, 52.80it/s]


Epoch 1/1, Training Loss: 0.3007


100%|██████████| 408/408 [00:07<00:00, 52.86it/s]


Epoch 1/1, Training Loss: 0.3034


100%|██████████| 406/406 [00:07<00:00, 53.11it/s]


Epoch 1/1, Training Loss: 0.3006


100%|██████████| 407/407 [00:07<00:00, 53.42it/s]


Epoch 1/1, Training Loss: 0.3019


100%|██████████| 409/409 [00:07<00:00, 52.63it/s]


Epoch 1/1, Training Loss: 0.3009


100%|██████████| 406/406 [00:07<00:00, 52.96it/s]


Epoch 1/1, Training Loss: 0.3001


100%|██████████| 411/411 [00:07<00:00, 52.78it/s]


Epoch 1/1, Training Loss: 0.3007


100%|██████████| 407/407 [00:07<00:00, 52.92it/s]


Epoch 1/1, Training Loss: 0.3022


100%|██████████| 405/405 [00:07<00:00, 52.96it/s]


Epoch 1/1, Training Loss: 0.3021


100%|██████████| 403/403 [00:07<00:00, 52.64it/s]


Epoch 1/1, Training Loss: 0.3036


100%|██████████| 405/405 [00:07<00:00, 52.68it/s]


Epoch 1/1, Training Loss: 0.3021


100%|██████████| 407/407 [00:07<00:00, 52.77it/s]


Epoch 1/1, Training Loss: 0.3047


100%|██████████| 405/405 [00:07<00:00, 52.94it/s]


Epoch 1/1, Training Loss: 0.3034


100%|██████████| 406/406 [00:07<00:00, 52.79it/s]


Epoch 1/1, Training Loss: 0.3024


100%|██████████| 406/406 [00:07<00:00, 52.43it/s]


Epoch 1/1, Training Loss: 0.3013


100%|██████████| 408/408 [00:07<00:00, 52.24it/s]


Epoch 1/1, Training Loss: 0.3018


100%|██████████| 403/403 [00:07<00:00, 52.57it/s]

Epoch 1/1, Training Loss: 0.3024





In [58]:
# test 
new_SLM_f16 = new_SLM.to(torch.float16)
new_SLM_f16.eval()

for i in range(200,210):
    data_X = torch.load(f'train/{sXlist[i]}').to(torch.float16).to('cuda')
    data_Y = torch.load(f'train/{sYlist[i]}').to(torch.float16).to('cuda')
    dataset = TensorDataset(data_X, data_Y)
    batch_size = 2048
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    # Initialize loss function
    mse_loss = torch.nn.MSELoss()

    total_loss = 0.0
    total_batches = 0

    # Iterate through DataLoader to process in batches
    with torch.inference_mode():
        for batch_X, batch_Y in dataloader:
            pred_Y = new_SLM_f16(batch_X)[0]
            loss = mse_loss(pred_Y, batch_Y)  # Calculate loss for the batch
            total_loss += loss.item()
            total_batches += 1

    # Calculate average loss
    average_loss = total_loss / total_batches
    print(average_loss)
    #break

2.3400782216893563
2.344645381265281


KeyboardInterrupt: 

In [57]:
torch.save(new_SLM.state_dict(),os.path.join('test_models',f'SLM_H15-V2_3.1_{str(0).zfill(10)}.pt'))

In [34]:
torch.cuda.empty_cache()

In [47]:
165+512

677

In [56]:
new_QV

Network_Qv_Universal_V1_2_BN_dropout(
  (dropout): Dropout(p=0.1, inplace=False)
  (fc1): Linear(in_features=677, out_features=1024, bias=True)
  (bn1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=1024, out_features=1024, bias=True)
  (bn2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=1024, out_features=1024, bias=True)
  (fc4): Linear(in_features=1024, out_features=1024, bias=True)
  (fc5): Linear(in_features=1024, out_features=1, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
)

In [50]:
# generate QV_x and train QV
new_SLM_f16 = new_SLM.to(torch.float16).to('cuda')
new_SLM_f16.eval()

new_QV.train()
new_QV.to('cuda')
opt = torch.optim.Adam(new_QV.parameters(), lr=0.0001, weight_decay=1e-8)
criterion = nn.MSELoss()

for i in range(0,10):
    sl_X = torch.load(f'train/{sXlist[i]}').to(torch.float16).to('cuda')
    qv_Y = torch.load(f'train/{qYlist[i]}').to(torch.float16).to('cuda')
    dataset = TensorDataset(sl_X, qv_Y)
    batch_size = 2048
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    running_loss = 0.0

    for batch_X, batch_Y in tqdm(loader):
        with torch.inference_mode():
            model_inter, qv_X = new_SLM_f16(batch_X)
        print(qv_X.shape,batch_X.shape,model_inter.shape, 8 * 15 + 2 * 15 + 512)
        qv_X = torch.concat([batch_X[:,0:8,0].view(batch_X.size(0), -1), # self
                                    model_inter, # upper and lower states
                                    #role,
                                    qv_X, # lstm encoded history
                                    ],dim=-1).to(torch.float32)
        print(qv_X.shape)
        
        opt.zero_grad()  # Zero the gradients
        outputs = new_QV(qv_X)  # Forward pass
        loss = criterion(outputs, batch_Y)  # Calculate loss
        loss.backward()  # Backward pass
        opt.step()  # Optimize
        
        running_loss += loss.item()
    
    # Calculate the average loss per batch over the epoch
    epoch_loss = running_loss / len(loader)
    print(f"Epoch 0 Training Loss: {epoch_loss:.4f}")
    break

  0%|          | 0/406 [00:00<?, ?it/s]

torch.Size([2048, 512]) torch.Size([2048, 22, 1, 15]) torch.Size([2048, 30]) 662
torch.Size([2048, 662])





RuntimeError: mat1 and mat2 shapes cannot be multiplied (2048x662 and 677x1024)

In [10]:
sllist = [f for f in os.listdir('train') if 'SLM' in f]
sllist.sort()
qvlist = [f for f in os.listdir('train') if 'QV' in f]
qvlist.sort()

In [11]:
new_SLM = Network_Pcard_V2_1_Trans(22, 7, y=1, x=15, trans_heads=4, trans_layers=6, hiddensize=512)

In [12]:
for i in range(3):
    ds_sl = torch.load(f'train/{sllist[i]}')
    train_loader = DataLoader(ds_sl, batch_size=512, shuffle=True, num_workers=0, pin_memory=True)
    opt = torch.optim.Adam(new_SLM.parameters(), lr=0.0001, weight_decay=1e-6)
    new_SLM = train_model('cuda', new_SLM, torch.nn.MSELoss(), train_loader, 1, opt)
    #break

100%|██████████| 1643/1643 [00:24<00:00, 67.84it/s]


Epoch 1/1, Training Loss: nan


 46%|████▌     | 755/1645 [00:11<00:12, 68.55it/s]


KeyboardInterrupt: 

In [34]:
new_SLM = Network_Pcard_V2_1_BN(22, 7, y=1, x=15, lstmsize=256, hiddensize=512)
new_QV = Network_Qv_Universal_V1_1_BN(6,15,512)

In [43]:
for i in range(3):
    ds_sl = torch.load(f'train/{sllist[i]}')
    train_loader = DataLoader(ds_sl, batch_size=512, shuffle=True, num_workers=0, pin_memory=True)
    opt = torch.optim.Adam(new_SLM.parameters(), lr=0.00001, weight_decay=1e-6)
    new_SLM = train_model('cuda', new_SLM, torch.nn.MSELoss(), train_loader, 1, opt)

    ds_qv = torch.load(f'train/{qvlist[i]}')
    train_loader = DataLoader(ds_qv, batch_size=512, shuffle=True, num_workers=0, pin_memory=True)
    opt = torch.optim.Adam(new_QV.parameters(), lr=0.00001, weight_decay=1e-6)
    new_QV = train_model('cuda', new_QV, torch.nn.MSELoss(), train_loader, 1, opt)

    #break

100%|██████████| 1643/1643 [00:11<00:00, 146.57it/s]


Epoch 1/1, Training Loss: 0.3287


100%|██████████| 1643/1643 [00:06<00:00, 252.84it/s]


Epoch 1/1, Training Loss: 0.1460


100%|██████████| 1645/1645 [00:10<00:00, 155.62it/s]


Epoch 1/1, Training Loss: 0.3273


100%|██████████| 1645/1645 [00:06<00:00, 238.39it/s]


Epoch 1/1, Training Loss: 0.1439


100%|██████████| 1646/1646 [00:11<00:00, 149.04it/s]


Epoch 1/1, Training Loss: 0.3268


100%|██████████| 1646/1646 [00:07<00:00, 234.47it/s]

Epoch 1/1, Training Loss: 0.1446





In [44]:
torch.save(new_SLM.state_dict(),os.path.join('models',f'SLM_H15-V2_2.3_{str(220025001).zfill(10)}.pt'))
torch.save(new_QV.state_dict(),os.path.join('models',f'QV_H15-V2_2.3_{str(220025001).zfill(10)}.pt'))

In [39]:
compiled_new_SLM = torch.jit.load(os.path.join('models',f'SLM_H15-V2_2.3_{str(220025001).zfill(10)}.pt'))

In [30]:
new_SLM = Network_Pcard_V2_1(22, 7, y=1, x=15, lstmsize=256, hiddensize=256)
new_QV = Network_Qv_Universal_V1_1(6,15,256)

In [None]:
for i in range(12):
    ds_sl = torch.load(f'train/{sllist[i]}')
    train_loader = DataLoader(ds_sl, batch_size=64, shuffle=True, num_workers=0, pin_memory=True)
    opt = torch.optim.Adam(new_SLM.parameters(), lr=0.00001, weight_decay=1e-6)
    new_SLM = train_model('cuda', new_SLM, torch.nn.MSELoss(), train_loader, 1, opt)

    ds_qv = torch.load(f'train/{qvlist[i]}')
    train_loader = DataLoader(ds_qv, batch_size=64, shuffle=True, num_workers=0, pin_memory=True)
    opt = torch.optim.Adam(new_QV.parameters(), lr=0.00001, weight_decay=1e-6)
    new_QV = train_model('cuda', new_QV, torch.nn.MSELoss(), train_loader, 1, opt)

    #break