In [None]:
import json
import os
from collections import defaultdict

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from navi.transforms import ToTensor
from navi.datasets.seq2seq_embeddings import SequenceToSequenceEmbeddings
from navi.datasets.seq2seq_embeddings_full import FullVideoEmbeddings
from navi.nn.models_using_embeddings import ResetNet50GRU
from navi.trainers.seq2seq import SequenceToSequenceTrainer

In [None]:
# Directory holding the precomputed embeddings read by the dataset below.
# Set the NAVI_EMBEDDINGS_ROOT environment variable to point elsewhere instead of
# editing this machine-specific default path.
root = os.environ.get(
    "NAVI_EMBEDDINGS_ROOT",
    "/run/media/ppoitier/ppoitier/datasets/navi/embeddings",
)

def load_videos(path="../data/maps/mapping.csv", n=20, random_state=42):
    """Load the video mapping table and draw a reproducible random subset of it.

    Args:
        path: CSV file listing the videos (defaults to the project mapping file).
        n: number of rows to sample, without replacement.
        random_state: seed forwarded to ``DataFrame.sample`` so the subset is
            identical across runs.

    Returns:
        A DataFrame of ``n`` sampled rows, re-indexed ``0..n-1``.

    Raises:
        ValueError: raised by pandas when ``n`` exceeds the number of rows in the CSV.
    """
    videos = pd.read_csv(path)
    return videos.sample(n=n, random_state=random_state, ignore_index=True)

def load_label_map(path="../prep/predictions.json"):
    """Load the JSON label map that is passed to the dataset as ``label_map``.

    Args:
        path: JSON file to read (defaults to the project predictions file).

    Returns:
        The parsed JSON content, as produced by ``json.load``.
    """
    # Binary mode is intentional: json.load accepts bytes and detects
    # UTF-8/16/32 encodings itself.
    with open(path, 'rb') as file:
        return json.load(file)

In [None]:
# Build the windowed embedding dataset (150-item windows, stride 130) and a
# batched loader over it.
transforms = ToTensor()
target_transform = ToTensor()

dataset = SequenceToSequenceEmbeddings(
    root=root,
    videos=load_videos(),
    label_map=load_label_map(),
    window_size=150,
    window_stride=130,
    transforms=transforms,
    target_transforms=target_transform,
    drop_empty_windows=True,
)

# Shuffle windows every epoch. Without it, every epoch sees the dataset's fixed
# order (presumably video by video), so batches are highly correlated. The seeded
# generator keeps the shuffle order reproducible across runs.
data_loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    drop_last=True,
    generator=torch.Generator().manual_seed(42),
)

In [None]:
device = 'cuda'

# Positive-class weight for BCEWithLogitsLoss: ratio of negative to positive
# target entries (~0.78 / 0.21, measured by the label-frequency cell later in
# this notebook).
# NOTE(review): that measurement runs *after* this cell. Move it above so the
# value is derived rather than pasted in.
POS_WEIGHT = 3.71

torch.manual_seed(42)  # reproducible weight initialisation
model = ResetNet50GRU(input_size=2048, hidden_size=256, n_layers=1)
# Create the weight on the configured device instead of hard-coding `.cuda()`.
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(POS_WEIGHT, device=device))
optimizer = optim.AdamW(model.parameters(), lr=2e-3)

trainer = SequenceToSequenceTrainer(
    "test_fc_resnet",
    model, criterion, optimizer,
    fold_nb=0,
    device=device,
    gradient_clipping=True,
)

In [None]:
# Train for 100 epochs with W&B logging. The try/finally guarantees that
# stop_wandb_logging() runs even if training raises or is interrupted
# (e.g. KeyboardInterrupt).
# NOTE(review): the same `data_loader` is passed twice, presumably as the
# training and the validation loader. Any validation metric therefore measures
# fit on the training windows, not generalisation. Build a held-out loader
# (e.g. from a disjoint set of videos) before trusting those numbers.
trainer.start_wandb_logging()
try:
    trainer.launch_training(data_loader, data_loader, n_epochs=100)
finally:
    trainer.stop_wandb_logging()

In [None]:
# Visualise the target sequence of the first dataset window.
example_features, example_target = dataset[0]

fig, ax = plt.subplots(figsize=(30, 4))
ax.scatter(range(len(example_target)), example_target)
ax.set(title="Target of dataset[0]", xlabel="Position in window", ylabel="Target")
plt.show()

In [None]:
# Run the trained model on the first dataset window and convert logits to
# probabilities.
# model.eval() switches off training-only behaviour (dropout, batch-norm
# statistic updates) if the model has any. The previous mode is restored
# afterwards, so a later re-run of the training cell starts from the same state.
was_training = model.training
model.eval()
try:
    with torch.inference_mode():
        x, y = dataset[0]
        # Add a batch dimension of size 1, then squeeze away the size-1 dims.
        logits = model(x.to(device).unsqueeze(0)).squeeze()
        probabilities = logits.sigmoid()
finally:
    model.train(was_training)

In [None]:
# Overlay the model's predicted probabilities on the ground-truth target of dataset[0].
fig, ax = plt.subplots(figsize=(30, 4))
ax.scatter(range(len(example_target)), example_target, label="target")
# Explicit colour so the two series stay distinguishable whatever the colour cycle.
ax.plot(probabilities.cpu().numpy(), color="tab:orange", label="predicted probability")
ax.set(title="Prediction vs. target for dataset[0]", xlabel="Position in window", ylabel="Value")
ax.legend()
plt.show()

In [None]:
# Tally how many target entries equal 0 and how many equal 1 across every
# window of the dataset, then normalise the tallies into relative frequencies.
counts = defaultdict(int)

for _, targets in dataset:
    for label in (0, 1):
        counts[label] += (targets == label).sum().item()

counts = dict(counts)
total = sum(counts.values())
frequencies = {label: n / total for label, n in counts.items()}
frequencies

In [None]:
# Positive-class weight for BCEWithLogitsLoss: the ratio of negative to
# positive target entries. It uses the exact counts from the previous cell
# rather than hand-rounded frequencies (0.78 / 0.21), which only approximate
# the ratio.
pos_weight = counts[0] / counts[1]
pos_weight