In [None]:
import os
from math import prod

import torch

from TFM import LlamaPredictor
from utils import torch_data, shuffle, blogm, bSqc, Neg, Sa, eps, create_train_test_split, save_checkpoint, load_checkpoint, save_checkpoint_and_test

# Global numeric / device configuration shared by the cells below.
dtype = torch.complex128  # complex double precision
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Build a reference model and report its total number of parameters.
# NOTE(review): this cell uses d=4 while the training cell sweeps d in [5];
# confirm the count shown below is the one you intend to quote.
mdl = LlamaPredictor(
    L_max=36,
    d=4,
    n_embd=96,
    n_layer=36,
    n_head=48,
    vocab_size=3,
    dropout_prob=0.0,
).to(device)

# numel() is the product of a tensor's dimensions, so this sums the element
# counts of every parameter tensor.
paras = sum(p.numel() for p in mdl.parameters())
paras

2333120

In [None]:
# ---- Run configuration ------------------------------------------------------
seed = 81            # base RNG seed; epoch e is re-seeded with seed + e
test_size = 1*10**6  # test-split size passed to create_train_test_split
N = 36               # passed to LlamaPredictor as L_max
batch = 1000         # batch size passed to create_train_test_split
file = 'save'        # output root: checkpoints in models/, metric records in record/
num_check = 20  # Number of checkpoints to save per epoch
start_epoch = 0  # Epoch to start/resume from
num_epochs = 10  # Training covers epochs start_epoch .. num_epochs - 1

# Create the record directory up front so the torch.save calls at the end of
# each epoch cannot fail on a missing directory after a full epoch of training.
os.makedirs(f'{file}/record', exist_ok=True)

for d in [5]:
    for theta_idx in [4]:
        for train_size in [81*10**6]:
            # Seed before building the model so the initial weights are reproducible.
            torch.manual_seed(seed)
            mdl = LlamaPredictor(L_max=N,
                                    d=d,
                                    n_embd=96,
                                    n_layer=36,
                                    n_head=48,
                                    vocab_size=3,
                                    dropout_prob=0.0).to(device)
            # mdl = LlamaPredictor(L_max=N,
            #                         d=d,
            #                         n_embd=12,
            #                         n_layer=6,
            #                         n_head=2,
            #                         vocab_size=3,
            #                         dropout_prob=0.0).to(device) # light model for quick testing
            optimizer = torch.optim.Adam(mdl.parameters(), lr=1e-4) # 0.0001
            l_train = {'msk on Sqc':[], 'loss':[]} # record mean train metrics at checkpoint saves
            l_test = {'loss':[], 'msk off Sqc':[], 'msk off Neg':[],'msk off Sa':[]} # record mean test metrics at checkpoint saves

            # Temporary storage for accumulating values between checkpoints
            temp_train = {'msk on Sqc':[], 'loss':[]}
            temp_test_test = {'loss':[], 'msk off Sqc':[], 'msk off Neg':[], 'msk off Sa':[]}

            # load train/test data
            prepseq_all = torch.load(f'data/theta{theta_idx}/all_prepseq_theta={theta_idx}.pt',weights_only=True)
            shadow_all = torch.load(f'data/theta{theta_idx}/all_shadow_state_theta={theta_idx}.pt',weights_only=True)
            rhoS_all = torch.load(f'data/theta{theta_idx}/all_rhoS_theta={theta_idx}.pt',weights_only=True)

            # Preprocess all prepseq data once (add 1 and append zero column)
            prepseq_all = torch.cat([prepseq_all+1, torch.zeros(prepseq_all.shape[0], 1, dtype=prepseq_all.dtype)], -1)

            # Create non-overlapping train/test split with batching
            train_data, test_data = create_train_test_split(
                prepseq_all, shadow_all, rhoS_all,
                train_size, test_size, batch
            )

            # Extract batched data for convenience
            prepseq_train = train_data['prepseq']
            shadow_state_train = train_data['shadow_state']
            rhoS_train = train_data['rhoS']

            prepseq_test = test_data['prepseq']
            shadow_state_test = test_data['shadow_state']
            rhoS_test = test_data['rhoS']


            # Calculate checkpoint saving interval
            # (the leading dimension of the batched tensors is used as the batch count)
            total_batches = prepseq_train.shape[0]
            save_interval = max(1, total_batches // num_check)  # Ensure at least 1
            print(f'Will save checkpoints every {save_interval} batches ({num_check} times per epoch)')

            # load checkpoint (resume from previous epoch's final checkpoint)
            if start_epoch > 0:
                # Load the final checkpoint from the previous epoch
                prev_epoch = start_epoch - 1
                # In-loop checkpoints are numbered (i+1)//save_interval - 1 and a trailing
                # partial interval gets number total_batches//save_interval; in both cases
                # the last number written equals (total_batches - 1)//save_interval.
                final_checkpoint_num = (total_batches - 1) // save_interval  # Last checkpoint of previous epoch
                checkpoint_info = load_checkpoint(mdl, optimizer, prev_epoch, final_checkpoint_num,
                                                save_dir=f'{file}/models',
                                                filename_prefix=f'model_d{d}_theta_idx{theta_idx}')
                print(f"Resumed from epoch {prev_epoch}, checkpoint {final_checkpoint_num}. Starting epoch {start_epoch}.")

                # Load previous training and test records
                try:
                    l_train = torch.load(f'{file}/record/epoch={prev_epoch}_d={d}_theta_idx={theta_idx}_size{train_size}_train.pt', weights_only=True)
                    l_test = torch.load(f'{file}/record/epoch={prev_epoch}_d={d}_theta_idx={theta_idx}_size{train_size}_test.pt', weights_only=True)
                    print(f"Loaded training records up to epoch {prev_epoch}. Train points: {len(l_train['loss'])}, Test points: {len(l_test['msk off Sqc'])}")
                except FileNotFoundError as e:
                    # Missing records are not fatal: training continues with empty histories.
                    print(f"Warning: Could not load previous records: {e}")
                    print("Starting with empty records.")

            # Save baseline checkpoint for fresh training runs (ensures consistent starting point)
            if start_epoch == 0:
                print('Saving baseline checkpoint...')
                # Uses the same positional layout as the in-loop and end-of-epoch calls
                # below (test tensors only). Passing the training tensors here as well
                # would shift every later positional argument.
                # NOTE(review): save_checkpoint_and_test lives in utils and its signature
                # is not visible from this notebook - confirm it takes only test tensors.
                save_checkpoint_and_test(mdl, optimizer, -1, 0,
                                        temp_train, temp_test_test,
                                        l_train, l_test,
                                        prepseq_test, shadow_state_test, rhoS_test,
                                        device, f'{file}/models', f'model_d{d}_theta_idx{theta_idx}',
                                        d, theta_idx, num_check)
            # Enable deterministic behavior for CUDA operations (may impact performance)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
            for epoch in range(start_epoch, num_epochs):
                # Set manual seed for reproducible training (after checkpoint loading)
                # Seeding with seed + epoch gives every epoch the same RNG state
                # whether the run started fresh or was resumed at that epoch.
                torch.manual_seed(seed + epoch)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed(seed + epoch)
                    torch.cuda.manual_seed_all(seed + epoch)
                # train loop
                print('='*50+'   Train   '+'='*50)
                mdl.train()

                # Shuffle individual samples (not batches) at the beginning of each epoch
                # Flatten to individual samples
                prepseq_flat = prepseq_train.view(-1, prepseq_train.shape[-1])
                shadow_state_flat = shadow_state_train.view(-1, shadow_state_train.shape[-1])
                rhoS_flat = rhoS_train.view(-1, rhoS_train.shape[-2], rhoS_train.shape[-1])

                # Use existing shuffle helper function
                # (presumably applies one common permutation to all three tensors - verify in utils)
                prepseq_shuffled, shadow_state_shuffled, rhoS_shuffled = shuffle(prepseq_flat, shadow_state_flat, rhoS_flat)

                # Re-batch the shuffled samples
                prepseq_train_shuffled = prepseq_shuffled.view(prepseq_train.shape)
                shadow_state_train_shuffled = shadow_state_shuffled.view(shadow_state_train.shape)
                rhoS_train_shuffled = rhoS_shuffled.view(rhoS_train.shape)

                for i in range(prepseq_train.shape[0]):
                    prepseq_batch, shadow_state_batch, rhoS_batch = prepseq_train_shuffled[i].clone(), shadow_state_train_shuffled[i].clone(), rhoS_train_shuffled[i].clone()
                    prepseq_batch = prepseq_batch.to(device)
                    shadow_state_batch = shadow_state_batch.to(device)
                    rhoS_batch = rhoS_batch.to(device)
                    # NOTE(review): the second positional argument (True) presumably switches
                    # masking on, given the 'msk on' metric name - confirm in TFM.
                    rhoC = mdl(prepseq_batch, True)
                    # Train
                    optimizer.zero_grad()
                    # Per-sample quadratic form conj(shadow) @ rhoC @ shadow, real part only.
                    probs = torch.bmm(torch.bmm(shadow_state_batch.conj().unsqueeze(1), rhoC), shadow_state_batch.unsqueeze(-1)).view(-1).real
                    # Negative mean log of those values (a negative log-likelihood).
                    loss = -probs.log().mean()
                    loss.backward()
                    optimizer.step()
                    temp_train['loss'].append(loss.item())
                    # The metric is only logged, so skip autograd tracking for it.
                    with torch.no_grad():
                        temp_train['msk on Sqc'].extend(bSqc(rhoS_batch, rhoC).tolist())
                    # Save checkpoint and run test at regular intervals
                    if (i+1) % save_interval == 0:
                        checkpoint_num = (i+1) // save_interval - 1
                        save_checkpoint_and_test(mdl, optimizer, epoch, checkpoint_num,
                                                temp_train, temp_test_test,
                                                l_train, l_test,
                                                prepseq_test, shadow_state_test, rhoS_test,
                                                device, f'{file}/models', f'model_d{d}_theta_idx{theta_idx}',
                                                d, theta_idx, num_check)

                    # Progress report every 100 steps; skipped when the accumulators are
                    # empty (presumably the checkpoint helper clears them - verify in utils).
                    if (i+1)%100 == 0 and temp_train['msk on Sqc'] and temp_train['loss']:
                        trainS = torch.tensor(temp_train['msk on Sqc']).mean().item()
                        loss_mean = torch.tensor(temp_train['loss']).mean().item()
                        print('epoch:  %3d | step:  %3d |  d:  %3d | theta_idx:  %3d | current Sqc: %.4f | current loss: %.4f' %(epoch, i, d, theta_idx, trainS, loss_mean))
                # Save final checkpoint at end of epoch (if not already saved)
                if total_batches % save_interval != 0 and (temp_train['loss'] or temp_train['msk on Sqc']):
                    final_checkpoint_num = total_batches // save_interval
                    save_checkpoint_and_test(mdl, optimizer, epoch, final_checkpoint_num,
                                            temp_train, temp_test_test,
                                            l_train, l_test,
                                            prepseq_test, shadow_state_test, rhoS_test,
                                            device, f'{file}/models', f'model_d{d}_theta_idx{theta_idx}',
                                            d, theta_idx, num_check, is_final=True)

                # Persist the metric histories once per epoch; the resume branch above
                # reads these same files back for start_epoch > 0.
                torch.save(l_train, f'{file}/record/epoch={epoch}_d={d}_theta_idx={theta_idx}_size{train_size}_train.pt')
                torch.save(l_test, f'{file}/record/epoch={epoch}_d={d}_theta_idx={theta_idx}_size{train_size}_test.pt')