# LLaMA Fine-Tuning for Quantum State Reconstruction

This notebook performs fine-tuning of pre-trained LLaMA transformer models for quantum state reconstruction across different system sizes (N=4 to N=34).

**Training Configuration:**
- **Data Split**: 8/9 training data, 1/9 test data (train_ratio = 8/9)
- **Fine-Tuning Approach**: Loads pre-trained models for each system size and performs additional training/refinement
- **Sequential Processing**: Processes each system size individually, loading the corresponding pre-trained model
- **Metrics**: Quantum coherence (Sqc), negativity (Neg), and Sa measures for performance evaluation

**Training Approach Comparison:**
- **This notebook (Fine-Tuning)**: Loads pre-trained specialized models for each system size and performs additional refinement training on individual system data.
- **Sequential GPT**: Trains models from scratch with transfer learning initialization.
- **Unified GPT**: Trains a single model on all system sizes simultaneously.

The fine-tuning approach allows for specialized refinement of already-trained models while maintaining the benefits of system-specific optimization.


In [1]:
from simulator import blogm, bSqc, Neg, Sa
from Llama2 import LlamaPredictor
import torch
from math import prod
from functools import reduce
import pandas
from utils import dtype, device, pauli, basis, torch_data

In [2]:
seed = 1
train, test = False, True
file = f'seed{seed}'
train_ratio = 8/9
batch = 500

# mdl = LlamaPredictor(L_max=35,
#                      n_embd=12, 
#                      n_layer=6, 
#                      n_head=6, 
#                      vocab_size=4, 
#                      dropout_prob=0.0).to(device)
mdl = LlamaPredictor(L_max=35,
                     n_embd=24, 
                     n_layer=12, 
                     n_head=6, 
                     vocab_size=4, 
                     dropout_prob=0.0).to(device)
#mdl.load_state_dict(torch.load(f'{file}/models/gpt_na.pt'))
    
total=0 # find size of the model
for p in mdl.parameters():
    total+=prod(p.shape)
total#, True_fid(mdl, psi)

49976

In [4]:
for N in range(4,36,2):
    # if N > 4:
    #     mdl.load_state_dict(torch.load(f'{file}/models/gpt_N={N-2}_na.pt'))
    mdl.load_state_dict(torch.load(f'{file}/models/gpt_N={N}_na.pt'))
    #mdl.load_state_dict(torch.load(f'{file}/models/gpt_na.pt'))
    torch.manual_seed(seed)
    prepseq, shadow_state, rhoS = torch_data(f'../data/data_{N}na.pickle', shuffle=True)
    prepseq = torch.cat([prepseq+2, torch.zeros(prepseq.shape[0], 32-prepseq.shape[1], dtype=prepseq.dtype), torch.ones(prepseq.shape[0], 1, dtype=prepseq.dtype)], 1)
    train_size = int(prepseq.shape[0]*train_ratio)
    test_size = prepseq.shape[0]-train_size
    
    prepseq_train, prepseq_test = prepseq[:train_size], prepseq[train_size:]
    shadow_state_train, shadow_state_test = shadow_state[:train_size], shadow_state[train_size:]
    rhoS_train, rhoS_test = rhoS[:train_size], rhoS[train_size:]
    
    # split in batches
    prepseq_train = prepseq_train.view(-1, batch, 33)
    shadow_state_train = shadow_state_train.view(-1, batch, 4)
    rhoS_train = rhoS_train.view(-1, batch, 4, 4)

    prepseq_test = prepseq_test.view(-1, batch, 33)
    shadow_state_test = shadow_state_test.view(-1, batch, 4)
    rhoS_test = rhoS_test.view(-1, batch, 4, 4)
    
    optimizer = torch.optim.Adam(mdl.parameters(), lr=1e-3) # 0.0001
    l = {'train Sqc':[], 'test Sqc':[], 'test Neg':[], 'test Sa':[], 'loss':[]}
    
    for epoch in range(1):
        # Train
        if train:
            print('='*50+'   Train   '+'='*50)
            mdl.train()
            for i in range(prepseq_train.shape[0]):
                rhoC = mdl(prepseq_train[i])
                l['train Sqc'].append(bSqc(rhoS_train[i], rhoC).mean().item())
                optimizer.zero_grad()
                probs = torch.bmm(torch.bmm(shadow_state_train[i].unsqueeze(1), rhoC), shadow_state_train[i].conj().unsqueeze(-1)).view(-1).real
                loss = -probs.log().mean()
                loss.backward()
                optimizer.step()
                l['loss'].append(loss.item())
                if (i+1)%100 == 0:
                    trainS = torch.tensor(l['train Sqc'])[-i:].mean().item()
                    loss = torch.tensor(l['loss'])[-i:].mean().item()
                    print('epoch:  %3d | step:  %3d | N:  %d | train Sqc: %.4f | loss: %.4f' %(epoch, i, N, trainS, loss))
        # Test
        if test:
            with torch.no_grad():
                print('='*50+'   Test   '+'='*50)
                mdl.eval()
                for i in range(prepseq_test.shape[0]):
                    rhoC = mdl(prepseq_test[i])
                    l['test Sqc'].append(bSqc(rhoS_test[i], rhoC).mean().item())
                    l['test Neg'].append(Neg(rhoS_test[i], rhoC).mean().item())
                    l['test Sa'].append(Sa(rhoS_test[i], rhoC).mean().item())
                    if (i+1)%100 == 0:
                        testS = torch.tensor(l['test Sqc'])[-i:].mean().item()
                        testN = torch.tensor(l['test Neg'])[-i:].mean().item()
                        print('epoch:  %3d | step:  %3d | N:  %d | test Sqc: %.4f | test Neg: %.4f' %(epoch, i, N, testS, testN))
        torch.save(l, f'{file}/record/gpt_N={N}_na.pt')
        torch.save(mdl.state_dict(), f'{file}/models/gpt_N={N}_na.pt')

epoch:    0 | step:   99 | N:  4 | test Sqc: 0.3762 | test Neg: 0.4104
epoch:    0 | step:  199 | N:  4 | test Sqc: 0.3720 | test Neg: 0.4133
epoch:    0 | step:   99 | N:  6 | test Sqc: 0.4154 | test Neg: 0.4030
epoch:    0 | step:  199 | N:  6 | test Sqc: 0.4094 | test Neg: 0.4043
epoch:    0 | step:   99 | N:  8 | test Sqc: 0.4981 | test Neg: 0.3727
epoch:    0 | step:  199 | N:  8 | test Sqc: 0.5023 | test Neg: 0.3720
epoch:    0 | step:   99 | N:  10 | test Sqc: 0.6339 | test Neg: 0.3277
epoch:    0 | step:  199 | N:  10 | test Sqc: 0.6330 | test Neg: 0.3280
epoch:    0 | step:   99 | N:  12 | test Sqc: 0.7257 | test Neg: 0.2898
epoch:    0 | step:  199 | N:  12 | test Sqc: 0.7160 | test Neg: 0.2953
epoch:    0 | step:   99 | N:  14 | test Sqc: 0.7497 | test Neg: 0.2795
epoch:    0 | step:  199 | N:  14 | test Sqc: 0.7440 | test Neg: 0.2824
epoch:    0 | step:   99 | N:  16 | test Sqc: 0.8113 | test Neg: 0.2559
epoch:    0 | step:  199 | N:  16 | test Sqc: 0.8178 | test Neg: 0.252