In [None]:
import os

import pandas as pd
import torch, torch.nn as nn
import tqdm

import utils
from model import MisalignmentDetector, uncertainty
from troncamento_datasets import BaseDataset, SegmentPairDataset

In [None]:
# Pick the best available accelerator instead of assuming Apple-silicon "mps",
# so the notebook also runs on CUDA or CPU-only machines. `device` stays a
# plain string because it is passed as-is to the helper calls further down.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = MisalignmentDetector().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()

# Resume from the pretraining checkpoint when one exists on disk.
pretrain_ckpt = "model_ckpt_pretrain.pt"
if os.path.isfile(pretrain_ckpt):
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint written on a different backend still loads here.
    model.load_state_dict(torch.load(pretrain_ckpt, map_location=device))

In [None]:
# Read the CSV that both BaseDataset splits below are built from.
df = pd.read_csv(filepath_or_buffer="troncamento_data.csv")

## Pretraining

In [None]:
# Two views over the same dataframe, selected via `dataset_type`; each base
# dataset is then wrapped in a SegmentPairDataset.
# NOTE(review): the exact meaning of "target" / "pre_train" and of
# `return_player` is defined in troncamento_datasets — confirm there.
target_basedataset = BaseDataset(
    df,
    dataset_type="target",
    return_player=False,
)
target_dataset = SegmentPairDataset(target_basedataset)

pretrain_basedataset = BaseDataset(
    df,
    dataset_type="pre_train",
    return_player=False,
)
pretrain_dataset = SegmentPairDataset(pretrain_basedataset)

In [None]:
# Run one pretraining pass over the "pre_train" pair dataset.
# NOTE(review): `pretrain_ckpt` is presumably where the helper saves the
# updated weights (it is the same path the setup cell loads from) — confirm
# in utils.pretrain_one_epoch.
utils.pretrain_one_epoch(
    pretrain_dataset,
    model,
    optimizer,
    criterion,
    device,
    pretrain_ckpt,
)

## Active learning loop

In [None]:
# Left as the cell's final expression so the return value is displayed.
# NOTE(review): presumably picks the `k` most uncertain items out of a random
# subset of `random_sample` target items — confirm in utils.
utils.select_uncertain_samples(
    model,
    target_basedataset,
    k=3,
    device=device,
    random_sample=10_000,
)

## Train on gold data

In [None]:
# The model *class* (not the `model` instance) is handed to the helper,
# together with the "model_ckpt" string and the target base dataset.
# NOTE(review): presumably the helper builds a fresh model and uses
# "model_ckpt" as a checkpoint name/prefix — confirm in utils.
utils.train_on_gold_dataset(
    MisalignmentDetector,
    "model_ckpt",
    target_basedataset,
    device=device,
)

In [None]:
# Refresh the "selected_samples" folder: first the delete helper, then the
# put helper, in that order.
# NOTE(review): which files count as "annotated" and what is copied from the
# TextGrid folder is decided inside utils — confirm there.
utils.delete_annotated_files(
    target_basedataset,
    folder_path="selected_samples",
)
utils.put_files_to_folder(
    target_basedataset,
    folder_path="selected_samples",
    tgrd_fs_folder="it_vxc_textgrids17_acoustic17",
)