In [1]:
import os
import sys
sys.path.insert(0, "..")
import inspect
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

import wandb
from safetensors import safe_open
from tqdm import tqdm

import matplotlib.pyplot as plt

from notebook_utils import *

from models import *
from my_datasets import *
from experiments import *

In [2]:
# Seed torch's RNG first so everything constructed below is repeatable.
torch.manual_seed(1234)

# Next-state model plus the dataset it is paired with (project helper).
ns_model, ns_dataset = quickload_next_state_model_and_dataset()

# Wrap the model for the attack; the positional `3` is presumably the number
# of appended attack tokens -- TODO confirm against the wrapper's signature.
atk_wrapper = ForceOutputWithAppendedAttackTokensWrapper(ns_model, 3)

# Number of variables, taken from the model's label count; reused by the
# training loop to normalise per-variable statistics.
num_vars = ns_model.num_labels

# Attack training set: each (lo, hi) pair is a range forwarded unchanged to
# the dataset constructor.
train_atk_dataset = ForceOutputWithAppendedAttackTokensDataset(
    num_vars=num_vars,
    num_rules_range=(16, 64),
    num_states_range=(8, 32),
    ante_prob_range=(0.2, 0.5),
    conseq_prob_range=(0.2, 0.5),
    state_prob_range=(0.4, 0.5),
    dataset_len=32768,
)

In [3]:
# Prefer the GPU, but fall back to CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing on the `.to(device)` call.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Move the wrapper to the chosen device and switch it to training mode.
atk_wrapper.to(device).train()

# Adam over every parameter the wrapper exposes via `.parameters()`.
optimizer = optim.Adam(atk_wrapper.parameters(), lr=5e-5)

# Batches of 16; no shuffling is requested here.
train_dataloader = DataLoader(train_atk_dataset, batch_size=16)
num_epochs = 10

In [None]:
# Train the attack wrapper for `num_epochs` passes over `train_dataloader`,
# showing running (epoch-to-date) statistics in the progress-bar description.
for e in range(1, num_epochs+1):
    # Epoch-local accumulators, kept as plain Python numbers so they do not
    # hold on to device tensors between steps.
    num_dones, running_hits = 0, 0
    running_loss, running_norm_loss, running_pred_loss = 0.0, 0.0, 0.0
    running_target_ones, running_pred_ones = 0, 0

    # `e` runs from 1 to num_epochs inclusive, so the total is num_epochs
    # (the previous `num_epochs+1` displayed e.g. "Epoch 1/11" for 10 epochs).
    print(f"Epoch {e}/{num_epochs}")
    pbar = tqdm(train_dataloader)
    for batch in pbar:
        tokens, target = batch["tokens"].to(device), batch["target"].to(device)
        out = atk_wrapper(tokens=tokens, target=target)
        # Binarise the wrapper's prediction by thresholding at zero
        # (presumably `out.pred` holds logits -- TODO confirm).
        pred = (out.pred > 0).long()
        loss, norm_loss, pred_loss = out.loss, out.norm_loss, out.pred_loss

        # One optimisation step; gradients are cleared afterwards so the next
        # backward pass starts from zero.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Update the running counters with Python scalars.
        num_dones += tokens.size(0)
        running_hits += (pred == target).sum().item()
        running_loss += loss.item()
        running_norm_loss += norm_loss.item()
        running_pred_loss += pred_loss.item()
        running_target_ones += target.sum().item()
        running_pred_ones += pred.sum().item()

        # Fraction of matching entries, normalised by samples * num_vars
        # (assumes pred/target have num_vars entries per sample -- TODO confirm).
        avg_acc = running_hits / (num_dones * num_vars)

        # NOTE(review): losses are summed once per batch but divided by the
        # number of samples; if the wrapper already returns a batch-mean loss
        # these values are scaled down by the batch size -- confirm intent.
        avg_loss = running_loss / num_dones
        avg_norm_loss = running_norm_loss / num_dones
        avg_pred_loss = running_pred_loss / num_dones

        # Fraction of ones in the targets and in the predictions.
        avg_tones = running_target_ones / (num_dones * num_vars)
        avg_pones = running_pred_ones / (num_dones * num_vars)

        desc = f"loss {avg_loss:.3f} ({avg_norm_loss:.3f}, {avg_pred_loss:.3f}), "
        desc += f"acc {avg_acc:.3f}, tones {avg_tones:.3f}, pones {avg_pones:.3f}"
        pbar.set_description(desc)

  0%|          | 0/2048 [00:00<?, ?it/s]

Epoch 1/11


loss 124.071 (11.519, 112.552), acc 0.491, tones 0.543, pones 0.584:   9%|▉         | 180/2048 [00:17<02:46, 11.19it/s]