In [1]:
from troncamento_datasets import BaseDataset, SegmentPairDataset
from model import MisalignmentDetector, uncertainty
import torch, torch.nn as nn
import pandas as pd
import os
import tqdm

In [2]:
# Pick the compute device. Apple's Metal backend ("mps") is tried first because
# that is what this notebook was written for; the fallbacks keep the cell
# runnable on machines where MPS is not available.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Model, optimizer and loss used by the pretraining step.
model = MisalignmentDetector().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()

# Resume from an existing pretraining checkpoint when one is present on disk.
pretrain_ckpt = "model_ckpt_pretrain.pt"
if os.path.isfile(pretrain_ckpt):
    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint written on a different device still loads here.
    model.load_state_dict(torch.load(pretrain_ckpt, map_location=device))

In [3]:
# Read the CSV that the BaseDataset instances in the following cell are built from.
data_csv_path = "troncamento_data.csv"
df = pd.read_csv(data_csv_path)

## Pretraining

In [4]:
def _build_dataset_pair(kind):
    """Build the BaseDataset for `kind` and the SegmentPairDataset wrapping it.

    Args:
        kind: value forwarded as `dataset_type` to BaseDataset.

    Returns:
        Tuple `(base_dataset, pair_dataset)`.
    """
    base = BaseDataset(df, dataset_type=kind, return_player=False)
    return base, SegmentPairDataset(base)


# Same construction order as before: target first, then pre_train.
target_basedataset, target_dataset = _build_dataset_pair("target")
pretrain_basedataset, pretrain_dataset = _build_dataset_pair("pre_train")

In [7]:
from torch.utils.data import random_split, DataLoader
import utils

In [None]:
# Run the pretraining helper on the "pre_train" segment-pair dataset with the
# model, optimizer and loss created in the setup cell.
# NOTE(review): pretrain_ckpt is passed as the last argument; presumably the
# helper writes the checkpoint to that path — confirm in utils.
# NOTE(review): this cell has no execution count in the saved notebook, so the
# outputs below were not necessarily produced after running it.
utils.pretrain_one_epoch(pretrain_dataset, model, optimizer, criterion, device, pretrain_ckpt)

## Active learning loop

In [9]:
# Values forwarded as `k` and `random_sample` to the selection helper.
# NOTE(review): presumably `random_sample` items are scored and the `k` most
# uncertain are kept — confirm in utils.
top_k = 20
n_random = 1000
utils.select_uncertain_samples(
    model,
    target_basedataset,
    k=top_k,
    device=device,
    random_sample=n_random,
)

100%|██████████| 1000/1000 [02:01<00:00,  8.20it/s]


## Train on gold data

In [10]:
# Folder names handed to the file-staging helper.
# NOTE(review): presumably files for the selected samples are copied from the
# TextGrid folder into the staging folder for annotation — confirm in utils.
staging_folder = "selected_samples"
textgrid_folder = "it_vxc_textgrids17_acoustic17"
utils.put_files_to_folder(
    target_basedataset,
    folder_path=staging_folder,
    tgrd_fs_folder=textgrid_folder,
)

In [12]:
# The model class (not an instance) is passed to the helper, together with the
# checkpoint directory; the captured output shows the checkpoint being written
# under this directory as model.pt.
gold_ckpt_dir = "model_ckpt"
utils.train_on_gold_dataset(
    MisalignmentDetector,
    gold_ckpt_dir,
    target_basedataset,
    device=device,
)

Training on 20 gold samples.


20it [00:03,  5.24it/s]


Trained on 20 gold samples: loss = 0.5851
Model checkpoint saved at  model_ckpt/model.pt


In [13]:
# Both helpers operate on the same staging folder: first the annotated files
# are deleted from it, then files are staged into it again.
annotation_folder = "selected_samples"
utils.delete_annotated_files(target_basedataset, folder_path=annotation_folder)
utils.put_files_to_folder(
    target_basedataset,
    folder_path=annotation_folder,
    tgrd_fs_folder="it_vxc_textgrids17_acoustic17",
)