In [1]:
import os
from functools import reduce
from math import prod

import pandas
import torch

from simulator import blogm, bSqc, Neg, Sa
from mpo import tel_mpo
from utils import dtype, device, pauli, basis, torch_data

In [2]:
# ---- run configuration ----
seed = 0                   # RNG seed; also used to name the output directory below
train, test = False, True  # which phases the loop in the next cell runs
file = f'seed{seed}'       # output root: models/ and record/ subdirectories live here
train_ratio = 8/9          # fraction of every dataset assigned to the training split
batch = 500                # mini-batch size for both splits

# ---- model ----
# NOTE(review): meaning of the first tel_mpo argument (34) is defined in mpo.py — TODO document
mdl = tel_mpo(34, bond=10)

# Number of trainable scalars in the model (shown as the cell output).
total = sum(param.numel() for param in mdl.parameters())
total

6800

In [3]:
for N in range(4, 36, 2):
    torch.manual_seed(seed)
    prepseq, shadow_state, rhoS = torch_data(f'../data/data_{N}na.pickle', shuffle=True)
    train_size = int(prepseq.shape[0]*train_ratio)
    test_size = prepseq.shape[0]-train_size
    
    prepseq_train, prepseq_test = prepseq[:train_size], prepseq[train_size:]
    shadow_state_train, shadow_state_test = shadow_state[:train_size], shadow_state[train_size:]
    rhoS_train, rhoS_test = rhoS[:train_size], rhoS[train_size:]
    
    # split in batches
    prepseq_train = prepseq_train.view(-1, batch, N-2)
    shadow_state_train = shadow_state_train.view(-1, batch, 4)
    rhoS_train = rhoS_train.view(-1, batch, 4, 4)

    prepseq_test = prepseq_test.view(-1, batch, N-2)
    shadow_state_test = shadow_state_test.view(-1, batch, 4)
    rhoS_test = rhoS_test.view(-1, batch, 4, 4)
    
    # load new model
    # if N == 4:
    #     mdl = tel_mpo(34, bond=10)
    # else:
    #     mdl.load_state_dict(torch.load(f'{file}/models/mpo_N={N-2}_na.pt'))
    
    # load old model
    mdl.load_state_dict(torch.load(f'{file}/models/mpo_N={N}_na.pt'))
    
    optimizer = torch.optim.Adam(mdl.parameters(), lr=1e-3) # 0.0001
    l = {'train Sqc':[], 'test Sqc':[], 'test Neg':[], 'test Sa':[], 'loss':[]}
    
    for epoch in range(1):
        # Train
        if train:
            print('='*50+'   Train   '+'='*50)
            mdl.train()
            for i in range(prepseq_train.shape[0]):
                rhoC = mdl(prepseq_train[i])
                l['train Sqc'].append(bSqc(rhoS_train[i], rhoC).mean().item())
                optimizer.zero_grad()
                probs = torch.bmm(torch.bmm(shadow_state_train[i].unsqueeze(1), rhoC), shadow_state_train[i].conj().unsqueeze(-1)).view(-1).real
                loss = -probs.log().mean()
                loss.backward()
                optimizer.step()
                l['loss'].append(loss.item())
                if (i+1)%100 == 0:
                    trainS = torch.tensor(l['train Sqc'])[-i:].mean().item()
                    loss = torch.tensor(l['loss'])[-i:].mean().item()
                    print('epoch:  %3d | step:  %3d | N:  %d | train Sqc: %.4f | loss: %.4f' %(epoch, i, N, trainS, loss))
        # Test
        if test:
            with torch.no_grad():
                print('='*50+'   Test   '+'='*50)
                mdl.eval()
                for i in range(prepseq_test.shape[0]):
                    rhoC = mdl(prepseq_test[i])
                    l['test Sqc'].append(bSqc(rhoS_test[i], rhoC).mean().item())
                    l['test Neg'].append(Neg(rhoS_test[i], rhoC).mean().item())
                    l['test Sa'].append(Sa(rhoS_test[i], rhoC).mean().item())
                    if (i+1)%100 == 0:
                        testS = torch.tensor(l['test Sqc'])[-i:].mean().item()
                        testN = torch.tensor(l['test Neg'])[-i:].mean().item()
                        print('epoch:  %3d | step:  %3d | N:  %d | test Sqc: %.4f | test Neg: %.4f' %(epoch, i, N, testS, testN))
        torch.save(l, f'{file}/record/mpo_N={N}_na.pt')
        torch.save(mdl.state_dict(), f'{file}/models/mpo_N={N}_na.pt')

epoch:    0 | step:   99 | N:  4 | test Sqc: 0.3515 | test Neg: 0.4203
epoch:    0 | step:  199 | N:  4 | test Sqc: 0.3657 | test Neg: 0.4145
epoch:    0 | step:   99 | N:  6 | test Sqc: 0.3963 | test Neg: 0.4073
epoch:    0 | step:  199 | N:  6 | test Sqc: 0.4094 | test Neg: 0.4028
epoch:    0 | step:   99 | N:  8 | test Sqc: 0.5226 | test Neg: 0.3698
epoch:    0 | step:  199 | N:  8 | test Sqc: 0.5207 | test Neg: 0.3695
epoch:    0 | step:   99 | N:  10 | test Sqc: 0.6314 | test Neg: 0.3313
epoch:    0 | step:  199 | N:  10 | test Sqc: 0.6347 | test Neg: 0.3293
epoch:    0 | step:   99 | N:  12 | test Sqc: 0.7272 | test Neg: 0.2954
epoch:    0 | step:  199 | N:  12 | test Sqc: 0.7248 | test Neg: 0.2948
epoch:    0 | step:   99 | N:  14 | test Sqc: 0.7338 | test Neg: 0.2931
epoch:    0 | step:  199 | N:  14 | test Sqc: 0.7476 | test Neg: 0.2852
epoch:    0 | step:   99 | N:  16 | test Sqc: 0.7923 | test Neg: 0.2677
epoch:    0 | step:  199 | N:  16 | test Sqc: 0.8025 | test Neg: 0.262