In [1]:
import os
import sys
sys.path.insert(0, "..")
import inspect
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

import wandb
from safetensors import safe_open
from tqdm import tqdm

import matplotlib.pyplot as plt

from notebook_utils import *

from models import *
from my_datasets import *
from experiments import *

In [2]:
# Seed first so the model load, the attack-token wrapper init and the dataset construction
# below are reproducible across kernel restarts.
torch.manual_seed(1234)

num_atk_tokens = 5  # number of attack tokens appended by the wrapper

# Pretrained next-state model plus its dataset; the model is the attack target.
ns_model, ns_dataset = quickload_next_state_model_and_dataset()
num_vars = ns_model.num_labels  # one binary label per state variable

# Wrapper that appends `num_atk_tokens` attack tokens to the input of `ns_model`.
# Kept before the dataset construction so any RNG draws happen in the original order.
atk_wrapper = ForceOutputWithAppendedAttackTokensWrapper(ns_model, num_atk_tokens)

# NOTE(review): the rule-count upper bound `64 - num_atk_tokens` presumably reserves room
# for the appended attack tokens inside a 64-slot budget — confirm against the model.
max_rules = 64 - num_atk_tokens
train_atk_dataset = ForceOutputWithAppendedAttackTokensDataset(
    num_vars=num_vars,
    num_rules_range=(16, max_rules),
    num_states_range=(8, 32),
    ante_prob_range=(0.2, 0.5),
    conseq_prob_range=(0.2, 0.5),
    state_prob_range=(0.4, 0.5),
    dataset_len=8192,
)

In [3]:
# Training configuration for the attack-token wrapper.
# Use the GPU when one is visible; fall back to CPU so the notebook does not crash on
# CPU-only machines (the original hardcoded "cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"
atk_wrapper.to(device).train()

# NOTE(review): `atk_wrapper.parameters()` will also include `ns_model`'s weights if the
# wrapper registers it as a submodule without freezing it (requires_grad=False); in that
# case Adam would update the target model too — confirm the wrapper freezes it.
optimizer = optim.Adam(atk_wrapper.parameters(), lr=5e-5)

# NOTE(review): DataLoader does not shuffle by default; fine if the dataset already yields
# randomly generated samples per index, otherwise consider shuffle=True.
train_dataloader = DataLoader(train_atk_dataset, batch_size=128)
num_epochs = 50

In [4]:
# Train the attack tokens. The progress bar shows running averages over the samples seen
# so far in the current epoch:
#   loss (norm_loss, pred_loss) — the wrapper's loss terms (the logged values suggest
#                                 loss = norm_loss + pred_loss)
#   acc   — fraction of (sample, variable) predictions matching the target
#   tones — fraction of 1s in the targets; pones — fraction of 1s in the predictions
# NOTE(review): losses are divided by the number of *samples*, which is a per-sample average
# only if out.loss is summed over the batch; if the wrapper returns a batch mean, this
# under-reports by ~batch_size — confirm against the wrapper.

# Re-assert train mode in case an evaluation cell switched the model to eval() in between.
atk_wrapper.train()

for e in range(1, num_epochs + 1):
    # Print the header before creating the bar so it appears above the bar in the output.
    print(f"Epoch {e}/{num_epochs}")
    num_dones, running_hits = 0, 0
    running_loss, running_norm_loss, running_pred_loss = 0.0, 0.0, 0.0
    running_target_ones, running_pred_ones = 0, 0
    pbar = tqdm(train_dataloader)
    for batch in pbar:
        tokens, target = batch["tokens"].to(device), batch["target"].to(device)

        # Zero first so stale gradients (e.g. from an interrupted previous run) never leak in.
        optimizer.zero_grad()
        out = atk_wrapper(tokens=tokens, target=target)
        loss, norm_loss, pred_loss = out.loss, out.norm_loss, out.pred_loss
        loss.backward()
        optimizer.step()

        # Binary prediction: positive logit -> 1, otherwise 0.
        pred = (out.pred > 0).long()

        # .item() keeps every accumulator a plain Python number instead of a device tensor.
        num_dones += tokens.size(0)
        running_hits += (pred == target).sum().item()
        running_loss += loss.item()
        running_norm_loss += norm_loss.item()
        running_pred_loss += pred_loss.item()
        running_target_ones += target.sum().item()
        running_pred_ones += pred.sum().item()

        num_cells = num_dones * num_vars  # total (sample, variable) predictions so far
        avg_acc = running_hits / num_cells
        avg_loss = running_loss / num_dones
        avg_norm_loss = running_norm_loss / num_dones
        avg_pred_loss = running_pred_loss / num_dones

        avg_tones = running_target_ones / num_cells
        avg_pones = running_pred_ones / num_cells

        desc = f"loss {avg_loss:.3f} ({avg_norm_loss:.3f}, {avg_pred_loss:.3f}), "
        desc += f"acc {avg_acc:.3f}, tones {avg_tones:.3f}, pones {avg_pones:.3f}"
        pbar.set_description(desc)

  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 1/50


loss 146.232 (18.135, 128.097), acc 0.492, tones 0.544, pones 0.598: 100%|██████████| 64/64 [00:40<00:00,  1.57it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 2/50


loss 108.697 (9.337, 99.360), acc 0.488, tones 0.544, pones 0.575: 100%|██████████| 64/64 [00:40<00:00,  1.59it/s]  
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 3/50


loss 97.749 (5.571, 92.179), acc 0.504, tones 0.544, pones 0.637: 100%|██████████| 64/64 [00:39<00:00,  1.60it/s] 
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 4/50


loss 94.653 (3.918, 90.735), acc 0.514, tones 0.545, pones 0.687: 100%|██████████| 64/64 [00:40<00:00,  1.58it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 5/50


loss 93.106 (3.059, 90.047), acc 0.517, tones 0.544, pones 0.707: 100%|██████████| 64/64 [00:40<00:00,  1.57it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 6/50


loss 92.803 (2.805, 89.998), acc 0.516, tones 0.544, pones 0.710: 100%|██████████| 64/64 [00:41<00:00,  1.56it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 7/50


loss 92.280 (2.499, 89.780), acc 0.517, tones 0.543, pones 0.726: 100%|██████████| 64/64 [00:40<00:00,  1.57it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 8/50


loss 92.073 (2.391, 89.682), acc 0.519, tones 0.544, pones 0.730: 100%|██████████| 64/64 [00:41<00:00,  1.55it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 9/50


loss 91.571 (2.138, 89.433), acc 0.522, tones 0.544, pones 0.743: 100%|██████████| 64/64 [00:41<00:00,  1.56it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 10/50


loss 91.442 (1.967, 89.475), acc 0.521, tones 0.544, pones 0.747: 100%|██████████| 64/64 [00:40<00:00,  1.57it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 11/50


loss 91.521 (2.015, 89.506), acc 0.521, tones 0.544, pones 0.749: 100%|██████████| 64/64 [00:42<00:00,  1.52it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 12/50


loss 91.471 (2.015, 89.457), acc 0.520, tones 0.543, pones 0.746: 100%|██████████| 64/64 [00:40<00:00,  1.58it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 13/50


loss 91.145 (1.797, 89.348), acc 0.523, tones 0.544, pones 0.760: 100%|██████████| 64/64 [00:41<00:00,  1.56it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 14/50


loss 91.659 (2.140, 89.518), acc 0.520, tones 0.543, pones 0.736: 100%|██████████| 64/64 [00:40<00:00,  1.57it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 15/50


loss 91.060 (1.709, 89.351), acc 0.522, tones 0.544, pones 0.766: 100%|██████████| 64/64 [00:40<00:00,  1.58it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 16/50


loss 91.273 (1.864, 89.409), acc 0.522, tones 0.544, pones 0.751: 100%|██████████| 64/64 [00:41<00:00,  1.55it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 17/50


loss 90.681 (1.561, 89.120), acc 0.525, tones 0.545, pones 0.778: 100%|██████████| 64/64 [00:41<00:00,  1.54it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 18/50


loss 90.686 (1.527, 89.160), acc 0.525, tones 0.544, pones 0.781: 100%|██████████| 64/64 [00:41<00:00,  1.56it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 19/50


loss 90.605 (1.516, 89.088), acc 0.525, tones 0.545, pones 0.784: 100%|██████████| 64/64 [00:40<00:00,  1.59it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 20/50


loss 90.543 (1.428, 89.115), acc 0.525, tones 0.544, pones 0.784: 100%|██████████| 64/64 [00:40<00:00,  1.58it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 21/50


loss 90.573 (1.490, 89.082), acc 0.524, tones 0.544, pones 0.784: 100%|██████████| 64/64 [00:41<00:00,  1.55it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 22/50


loss 90.742 (1.567, 89.175), acc 0.524, tones 0.544, pones 0.780: 100%|██████████| 64/64 [00:40<00:00,  1.57it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 23/50


loss 90.575 (1.425, 89.150), acc 0.523, tones 0.544, pones 0.778: 100%|██████████| 64/64 [00:40<00:00,  1.58it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 24/50


loss 90.454 (1.457, 88.998), acc 0.526, tones 0.544, pones 0.788: 100%|██████████| 64/64 [00:40<00:00,  1.56it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 25/50


loss 90.430 (1.371, 89.060), acc 0.525, tones 0.544, pones 0.788: 100%|██████████| 64/64 [00:40<00:00,  1.57it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 26/50


loss 90.565 (1.451, 89.114), acc 0.524, tones 0.544, pones 0.777: 100%|██████████| 64/64 [00:40<00:00,  1.57it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 27/50


loss 90.296 (1.301, 88.995), acc 0.526, tones 0.544, pones 0.798: 100%|██████████| 64/64 [00:40<00:00,  1.59it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 28/50


loss 90.804 (1.599, 89.206), acc 0.524, tones 0.544, pones 0.778: 100%|██████████| 64/64 [00:40<00:00,  1.58it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 29/50


loss 90.802 (1.533, 89.269), acc 0.523, tones 0.544, pones 0.764: 100%|██████████| 64/64 [00:40<00:00,  1.56it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 30/50


loss 90.225 (1.271, 88.954), acc 0.526, tones 0.544, pones 0.798: 100%|██████████| 64/64 [00:39<00:00,  1.61it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 31/50


loss 90.619 (1.464, 89.155), acc 0.525, tones 0.543, pones 0.787: 100%|██████████| 64/64 [00:41<00:00,  1.55it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 32/50


loss 90.784 (1.558, 89.226), acc 0.523, tones 0.544, pones 0.772: 100%|██████████| 64/64 [00:41<00:00,  1.55it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 33/50


loss 90.647 (1.449, 89.198), acc 0.525, tones 0.545, pones 0.777: 100%|██████████| 64/64 [00:41<00:00,  1.55it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 34/50


loss 90.440 (1.413, 89.027), acc 0.525, tones 0.543, pones 0.795: 100%|██████████| 64/64 [00:41<00:00,  1.56it/s]
  0%|          | 0/64 [00:00<?, ?it/s]

Epoch 35/50


loss 90.311 (1.195, 89.116), acc 0.524, tones 0.544, pones 0.785:  44%|████▍     | 28/64 [00:18<00:23,  1.54it/s]


KeyboardInterrupt: 