In [1]:
from functools import reduce
from math import prod
import os

import pandas
import torch

from simulator import blogm, bSqc, Neg, Sa
from Llama2 import LlamaPredictor
from utils import dtype, device, pauli, basis, torch_data, shuffle

In [2]:
seed = 0
test = True           # evaluate on the test split after every training epoch
file = f'seed{seed}'  # output directory for metric records and model checkpoints
train_ratio = 5/6     # NOTE(review): not referenced in the cells shown here
batch = 500           # samples per mini-batch

# Seed torch's RNGs (CPU and all CUDA devices) so the model initialisation in
# the following cells is reproducible; previously `seed` only named the
# output folder and nothing was actually seeded.
torch.manual_seed(seed);

In [3]:
# Post-selected datasets, stored as separate train / test files.
# The directory is named once instead of being repeated in six f-strings
# that contained no placeholders.
DATA_DIR = '../data/post_selected'

prepseq_train = torch.load(f'{DATA_DIR}/prepseq_train.pt')
shadow_state_train = torch.load(f'{DATA_DIR}/shadow_state_train.pt')
rhoS_train = torch.load(f'{DATA_DIR}/rhoS_train.pt')

prepseq_test = torch.load(f'{DATA_DIR}/prepseq_test.pt')
shadow_state_test = torch.load(f'{DATA_DIR}/shadow_state_test.pt')
rhoS_test = torch.load(f'{DATA_DIR}/rhoS_test.pt')

In [4]:
# Regroup the flat datasets into mini-batches of `batch` samples each.
seq_len = 15    # length of each preparation sequence
state_dim = 4   # trailing size of shadow states and of the square rhoS matrices
n_groups = 7    # NOTE(review): presumably one test group per system size (training cell uses N = 2n+4) — confirm

prepseq_train, shadow_state_train, rhoS_train = (
    prepseq_train.view(-1, batch, seq_len),
    shadow_state_train.view(-1, batch, state_dim),
    rhoS_train.view(-1, batch, state_dim, state_dim),
)

# Test data keeps an extra leading axis of `n_groups` groups.
prepseq_test, shadow_state_test, rhoS_test = (
    prepseq_test.view(n_groups, -1, batch, seq_len),
    shadow_state_test.view(n_groups, -1, batch, state_dim),
    rhoS_test.view(n_groups, -1, batch, state_dim, state_dim),
)

prepseq_train.shape, prepseq_test.shape

(torch.Size([7000, 500, 15]), torch.Size([7, 200, 500, 15]))

In [5]:
# Model and optimizer for this run.
mdl = LlamaPredictor(
    L_max=17,
    n_embd=24,
    n_layer=12,
    n_head=6,
    vocab_size=4,
    dropout_prob=0.0,
).to(device)
optimizer = torch.optim.Adam(mdl.parameters(), lr=1e-3)  # an alternative lr of 1e-4 was noted

# Metric history, appended to by the training cell.
l = {key: [] for key in ('train Sqc', 'test Sqc', 'test Neg', 'test Sa', 'loss')}

# Total number of scalar parameters in the model (displayed as the cell output).
total = sum(param.numel() for param in mdl.parameters())
total

49976

In [6]:
# Alternative, smaller model configuration (n_embd=12, n_layer=6) kept for
# reference only; uncomment it to use it in place of the configuration in
# the preceding model cell.
# mdl = LlamaPredictor(L_max=17,
#                      n_embd=12, 
#                      n_layer=6, 
#                      n_head=6, 
#                      vocab_size=4, 
#                      dropout_prob=0.0).to(device)
# optimizer = torch.optim.Adam(mdl.parameters(), lr=1e-3) # 0.0001
# l = {'train Sqc':[], 'test Sqc':[], 'test Neg':[], 'test Sa':[], 'loss':[]}
# total=0 # find size of the model
# for p in mdl.parameters():
#     total+=prod(p.shape)
# total

In [7]:
n_epochs = 20    # number of passes over the training batches
log_every = 100  # print running averages every `log_every` batches

# Create the output directories up front; otherwise torch.save fails only
# after a whole epoch of training has already been spent.
os.makedirs(f'{file}/record', exist_ok=True)
os.makedirs(f'{file}/models', exist_ok=True)

for epoch in range(n_epochs):
    # ---------------------------- Train ----------------------------
    print('='*50+'   Train   '+'='*50)
    mdl.train()
    for i in range(prepseq_train.shape[0]):
        rhoC = mdl(prepseq_train[i])
        # Monitoring metric only: detach so no autograd graph is built for it.
        l['train Sqc'].append(bSqc(rhoS_train[i], rhoC.detach()).mean().item())
        optimizer.zero_grad()
        # Per-sample quadratic form  s · rhoC · conj(s)  for each shadow state s
        # of the batch; its real part is used as the model probability.
        # Assumes rhoC is (batch, 4, 4) — TODO confirm against LlamaPredictor.
        probs = torch.bmm(torch.bmm(shadow_state_train[i].unsqueeze(1), rhoC),
                          shadow_state_train[i].conj().unsqueeze(-1)).view(-1).real
        loss = -probs.log().mean()  # negative log-likelihood of the observed shadows
        loss.backward()
        optimizer.step()
        l['loss'].append(loss.item())
        if (i+1) % log_every == 0:
            # Average over the i+1 batches of this epoch (the old `[-i:]`
            # dropped one). Slice the list before building a tensor so the
            # cost does not grow with the full training history.
            trainS = torch.tensor(l['train Sqc'][-(i+1):]).mean().item()
            avg_loss = torch.tensor(l['loss'][-(i+1):]).mean().item()
            print('epoch:  %3d | step:  %3d | train Sqc: %.4f | loss: %.4f' %(epoch, i, trainS, avg_loss))

    # ---------------------------- Test -----------------------------
    if test:
        with torch.no_grad():
            print('='*50+'   Test   '+'='*50)
            mdl.eval()
            for n in range(prepseq_test.shape[0]):
                N = n*2+4  # label of test group n (presumably the system size)
                for i in range(prepseq_test.shape[1]):
                    rhoC = mdl(prepseq_test[n,i])
                    rhoS = rhoS_test[n,i]
                    l['test Sqc'].append([N, bSqc(rhoS, rhoC).mean().item()])
                    l['test Neg'].append([N, Neg(rhoS, rhoC).mean().item()])
                    # Previously stored Neg(...) under 'test Sa' (copy-paste slip).
                    # NOTE(review): assumes Sa takes (rhoS, rhoC) like bSqc/Neg — confirm in simulator.
                    l['test Sa'].append([N, Sa(rhoS, rhoC).mean().item()])
                    if (i+1) % log_every == 0:
                        # Mean of the value column over the i+1 batches of group n.
                        testS = torch.tensor(l['test Sqc'][-(i+1):]).mean(0)[-1].item()
                        testN = torch.tensor(l['test Neg'][-(i+1):]).mean(0)[-1].item()
                        print('epoch:  %3d | step:  %3d | N:  %d | test Sqc: %.4f | test Neg: %.4f' %(epoch, i, N, testS, testN))

    # Checkpoint after every epoch (overwrites the previous epoch's files).
    torch.save(l, f'{file}/record/gpt_pa.pt')
    torch.save(mdl.state_dict(), f'{file}/models/gpt_pa.pt')

epoch:    0 | step:   99 | train Sqc: 1.4100 | loss: 1.3901
epoch:    0 | step:  199 | train Sqc: 1.3880 | loss: 1.3867
epoch:    0 | step:  299 | train Sqc: 1.3704 | loss: 1.3838
epoch:    0 | step:  399 | train Sqc: 1.3475 | loss: 1.3803
epoch:    0 | step:  499 | train Sqc: 1.3246 | loss: 1.3766
epoch:    0 | step:  599 | train Sqc: 1.3085 | loss: 1.3737
epoch:    0 | step:  699 | train Sqc: 1.2942 | loss: 1.3714
epoch:    0 | step:  799 | train Sqc: 1.2841 | loss: 1.3698
epoch:    0 | step:  899 | train Sqc: 1.2737 | loss: 1.3683
epoch:    0 | step:  999 | train Sqc: 1.2639 | loss: 1.3670
epoch:    0 | step:  1099 | train Sqc: 1.2570 | loss: 1.3658
epoch:    0 | step:  1199 | train Sqc: 1.2504 | loss: 1.3648
epoch:    0 | step:  1299 | train Sqc: 1.2430 | loss: 1.3638
epoch:    0 | step:  1399 | train Sqc: 1.2382 | loss: 1.3629
epoch:    0 | step:  1499 | train Sqc: 1.2328 | loss: 1.3619
epoch:    0 | step:  1599 | train Sqc: 1.2267 | loss: 1.3609
epoch:    0 | step:  1699 | train 