In [1]:
import os
os.chdir("../")

from data_management.DataManager import CALFData, collateGCN
import numpy as np
import torch 
from Model import ContextAwareModel, SpottingModel, SegmentationModel
from helpers.loss import ContextAwareLoss, SpottingLoss
from modules.train import trainer
import pickle
from dataclasses import dataclass
import matplotlib.pyplot as plt
from modules.Visualiser import collateVisGCN, Visualiser
import seaborn as sns
from helpers.classes import EVENT_DICTIONARY_V2_ALIVE as event_enc
from helpers.classes import get_K_params

In [2]:
@dataclass
class Args:
    """Run configuration handed to the dataset, the model and the training cell.

    None of the attributes carries a type annotation, so ``@dataclass`` does
    not register any of them as fields; they are plain class attributes and
    can be read directly from the class object (``Args.chunk_size``).
    """

    # ---------------- Data ----------------
    chunk_size = 60
    batch_size = 32
    input_channel = 13
    annotation_nr = 10
    receptive_field = 12
    fps = 5
    # Evaluated once, while the class body executes, from the chunk_size above.
    K_parameters = get_K_params(chunk_size)
    focused_annotation = None
    generate_augmented_data = True
    class_split = "alive"
    generate_artificial_targets = False

    # ---------------- Training ----------------
    chunks_per_epoch = 1824
    lambda_coord = 5.0
    lambda_noobj = 0.5
    patience = 25
    LR = 1e-3
    max_epochs = 180
    GPU = 0
    max_num_worker = 1
    loglevel = "INFO"

    # ---------------- Segmentation module ----------------
    feature_multiplier = 1
    backbone_player = "GCN"
    load_weights = None
    model_name = "Testing_Model"
    dim_capsule = 16
    vocab_size = 64
    pooling = "NetVLAD"

    # ---------------- Spotting module ----------------
    # NOTE(review): "sgementation_path" looks like a misspelling of
    # "segmentation_path". The name is kept because code outside this notebook
    # may look it up under exactly this spelling -- confirm before renaming.
    sgementation_path = None
    freeze_model = None

In [3]:
# ---- Configuration and data ------------------------------------------------
# The Args class object itself (not an instance) is used as the configuration.
args = Args
collate_fn = collateGCN

train_dataset = CALFData(split="train", args=args)
validation_dataset = CALFData(split="validate", args=args)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)

# NOTE(review): the validation loader is shuffled too. Kept unchanged here;
# confirm whether the trainer relies on it before switching to shuffle=False.
validate_loader = torch.utils.data.DataLoader(
    validation_dataset,
    batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)

# ---- Model, loss, optimiser, LR schedule ------------------------------------
model = SegmentationModel(args=args)
criterion = ContextAwareLoss(K=train_dataset.K_parameters)

optimizer = torch.optim.Adam(model.parameters(), lr=args.LR,
                             betas=(0.9, 0.999), eps=1e-07,
                             weight_decay=0, amsgrad=False)

# NOTE(review): `verbose` is deprecated in recent PyTorch releases -- confirm
# against the installed version before upgrading.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'min', verbose=True, patience=args.patience)

# ---- Output locations -------------------------------------------------------
# Create the output folders *before* the long training run, so a missing
# directory cannot cause the results to be lost once training has finished.
CHECKPOINT_PATH = "models/CALF_NetVLAD_GCN.pth.tar"
RESULTS_PATH = "results/CALF_NetVLAD_GCN.pkl"
for output_path in (CHECKPOINT_PATH, RESULTS_PATH):
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

# ---- Training ---------------------------------------------------------------
losses = trainer(train_loader, validate_loader,
                 model, optimizer, scheduler,
                 criterion,
                 model_name=args.model_name,
                 max_epochs=args.max_epochs,
                 save_dir=CHECKPOINT_PATH)

# Persist the value returned by the trainer before any clean-up.
with open(RESULTS_PATH, 'wb') as file:
    pickle.dump(losses, file)

# Only the validation objects are released. `train_loader` (and the dataset it
# wraps) must stay alive because the next cell draws a batch from it; deleting
# it here made that cell fail with NameError on a fresh Restart & Run All.
del validation_dataset, validate_loader

Data preprocessing:   0%|          | 0/1 [00:00<?, ?it/s]

Data preprocessing: 100%|██████████| 1/1 [00:17<00:00, 17.99s/it]
Get labels & features: 100%|██████████| 4/4 [00:31<00:00,  7.99s/it]


In [4]:
# Take the first batch produced by the training loader. The batch unpacks into
# three items; only the third one (`rep`) is kept and passed to the model in
# the following cell.
_, _, rep = next(iter(train_loader))

In [6]:
# Forward pass of the segmentation model on the sampled batch. The captured
# output of this cell shows two tensor-size prints ("GNN output size" and
# "NetVLAD output size"), presumably emitted inside the model's forward().
result = model(rep)

GNN output size:  torch.Size([32, 152, 300, 1])
NetVLAD output size:  torch.Size([32, 10, 300, 1])
