In [1]:
import pickle
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch

from DataManager import CALFData, collateGCN
from helpers.classes import EVENT_DICTIONARY_V2_ALIVE as event_enc
from helpers.classes import get_K_params
from helpers.evaluation import segmentation_correlation
from helpers.preprocessing import animate_clip
from SpottingModel import SpottingModel
from train import trainer
from Visualiser import collateVisGCN, Visualiser

In [4]:
@dataclass
class Args:
    """Experiment configuration shared by the data pipeline, model and trainer.

    Every setting is an annotated dataclass field with a default. The class can
    therefore still be used directly as a namespace (``Args.batch_size``, or
    ``args = Args``). It can also be instantiated with per-run overrides,
    e.g. ``Args(max_epochs=5)``.

    Settings whose meaning is not visible in this notebook are consumed by the
    project modules (``CALFData``, ``SpottingModel``, ``trainer``). Consult
    those modules for their authoritative semantics.
    """

    receptive_field: int = 6  # NOTE(review): semantics defined by CALFData/SpottingModel — confirm
    fps: int = 5  # presumably the frame rate of the input data — confirm
    chunks_per_epoch: int = 1824
    class_split: str = "alive"  # presumably selects the event set (cf. EVENT_DICTIONARY_V2_ALIVE)
    chunk_size: int = 30  # also the input used to compute K_parameters below
    batch_size: int = 32  # DataLoader batch size
    input_channel: int = 13
    feature_multiplier: int = 1
    backbone_player: str = "GCN"
    max_epochs: int = 180  # passed to trainer(max_epochs=...)
    load_weights: Optional[str] = None  # presumably a checkpoint path to resume from — confirm
    model_name: str = "Testing_Model"  # passed to trainer(model_name=...)
    dim_capsule: int = 16
    lambda_coord: float = 5.0
    lambda_noobj: float = 0.5
    patience: int = 25  # ReduceLROnPlateau patience (epochs without improvement)
    LR: float = 1e-03  # Adam learning rate
    GPU: int = 0
    max_num_worker: int = 1
    loglevel: str = 'INFO'
    annotation_nr: int = 10
    # Deliberately NOT a dataclass field (no annotation). It is evaluated once,
    # at class creation, from the default chunk_size above. An instance created
    # with Args(chunk_size=...) therefore still sees this class-level value.
    K_parameters = get_K_params(chunk_size)
    focused_annotation: Any = None
    generate_augmented_data: bool = True
    # (sic) The name is misspelled ("sgementation"). It is kept as-is because
    # project modules presumably read it under this name — TODO confirm before renaming.
    sgementation_path: str = "models/detector_probs.pth.tar"
    freeze_model: bool = True

In [3]:
# Build only the validation split, e.g. for quick inspection without the
# (much slower) training-set preprocessing.
# NOTE(review): the training cell below rebuilds this same dataset/loader.
args = Args()
collate_fn = collateGCN

validation_dataset = CALFData(split="validate", args=args)
# Evaluation order carries no training signal. Keep it fixed so validation
# runs are reproducible.
validate_loader = torch.utils.data.DataLoader(
    validation_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)

Data preprocessing:   0%|          | 0/1 [00:02<?, ?it/s]


KeyboardInterrupt: 

In [109]:
from helpers.preprocessing import animate_clip
label, target, representation = next(iter(validation_dataset))
coords_arr = np.array([rep.x for rep in representation]).transpose((0,2,1))[15:-15]
target.shape[0]
annotation = "Pass"

animate_clip(coords_arr, target, annotation)

In [5]:
# Train the spotting model and persist the per-epoch losses returned by `trainer`.
# NOTE(review): data preprocessing alone took ~18 min in the saved outputs, so
# consider caching the datasets if this cell is re-run often.
SEED = 0
SPOTTING_CHECKPOINT = "models/spotting.pth.tar"
SPOTTING_LOSSES = Path("results/spotting.pkl")

# Seed the RNGs used here (model init, shuffling) and presumably by data augmentation.
torch.manual_seed(SEED)
np.random.seed(SEED)

args = Args()
collate_fn = collateGCN

train_dataset = CALFData(split="train", args=args)
validation_dataset = CALFData(split="validate", args=args)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
# Evaluation order carries no training signal. Keep it fixed so validation runs are reproducible.
validate_loader = torch.utils.data.DataLoader(
    validation_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)

model = SpottingModel(args=args)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.LR,
                             betas=(0.9, 0.999), eps=1e-07,
                             weight_decay=0, amsgrad=False)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'min', verbose=True, patience=args.patience)

# Create output directories up front. Otherwise a long run could fail at the
# very end when writing into a folder that does not exist.
Path(SPOTTING_CHECKPOINT).parent.mkdir(parents=True, exist_ok=True)
SPOTTING_LOSSES.parent.mkdir(parents=True, exist_ok=True)

losses = trainer(train_loader, validate_loader,
                 model, optimizer, scheduler,
                 criterion=criterion,
                 model_name=args.model_name,
                 max_epochs=args.max_epochs,
                 save_dir=SPOTTING_CHECKPOINT,
                 train_seg=False)

with SPOTTING_LOSSES.open('wb') as file:
    pickle.dump(losses, file)

Data preprocessing: 100%|██████████| 10/10 [03:02<00:00, 18.23s/it]
Get labels & features: 100%|██████████| 40/40 [12:40<00:00, 19.02s/it]
Data preprocessing: 100%|██████████| 2/2 [00:35<00:00, 17.60s/it]
Get labels & features: 100%|██████████| 8/8 [02:19<00:00, 17.43s/it]
Train 1: Time 1.538s (it:1.419s) Data:0.538s (it:0.478s) Loss 3.4395e-01 : 100%|████████████████████████████████████████████████| 57/57 [01:27<00:00,  1.54s/it]
Evaluate 1: Time 1.295s (it:1.247s) Data:0.363s (it:0.336s) Loss 2.0849e-01 : 100%|█████████████████████████████████████████████| 57/57 [01:13<00:00,  1.29s/it]
Train 2: Time 1.525s (it:1.759s) Data:0.485s (it:0.473s) Loss 1.9296e-01 : 100%|████████████████████████████████████████████████| 57/57 [01:26<00:00,  1.52s/it]
Evaluate 2: Time 1.415s (it:1.427s) Data:0.404s (it:0.359s) Loss 1.6790e-01 : 100%|█████████████████████████████████████████████| 57/57 [01:20<00:00,  1.41s/it]
Train 3: Time 1.581s (it:1.439s) Data:0.516s (it:0.472s) Loss 1.7324e-01 : 100%|██

KeyboardInterrupt: 