In [None]:
import os
os.chdir("../")
from DataManager import CALFData, collateGCN
import numpy as np
import torch 
from Model import ContextAwareModel
from helpers.loss import ContextAwareLoss, SpottingLoss
from train import trainer
import pickle
from dataclasses import dataclass
import matplotlib.pyplot as plt
from Visualiser import collateVisGCN, Visualiser, VisualiseDataset
import seaborn as sns
from helpers.classes import EVENT_DICTIONARY_V2_ALIVE as event_enc
from helpers.classes import get_K_params
import torch.nn as nn

In [None]:
@dataclass
class Args:
    """Configuration namespace for the per-annotation fine-tuning run.

    None of the attributes carries a type annotation, so ``@dataclass`` creates
    no fields: every value below is a plain class attribute, and the notebook
    uses the class itself (``args = Args``) as a namespace.
    """

    # --- data / chunking ---
    receptive_field = 12
    fps = 5
    chunks_per_epoch = 1824
    class_split = "alive"
    chunk_size = 60
    batch_size = 32
    annotation_nr = 1
    # Evaluated once, at class-definition time, from the `chunk_size` above;
    # changing `chunk_size` afterwards does not refresh it.
    K_parameters = get_K_params(chunk_size)
    # Overwritten for every annotation by the training cell.
    focused_annotation = "Duel"
    generate_augmented_data = True

    # --- model ---
    input_channel = 13
    feature_multiplier=1
    backbone_player = "GCN"
    dim_capsule=16
    load_weights=None
    model_name="Testing_Model"
    # (sic) The misspelled name is kept because the training cell reads it.
    sgementation_path = "models/gridsearch5.pth.tar"
    freeze_model = True

    # --- optimisation ---
    max_epochs=180
    patience=25
    LR=1e-03
    lambda_coord=5.0
    lambda_noobj=0.5
    # The training cell passes the three values below to `trainer`; without
    # them it fails with AttributeError on a fresh kernel.
    # NOTE(review): the values are the defaults of the reference CALF
    # implementation -- confirm they match the run that produced
    # `sgementation_path`.
    loss_weight_segmentation = 0.000367
    loss_weight_detection = 1.0
    evaluation_frequency = 20

    # --- runtime ---
    GPU=0
    max_num_worker=1
    loglevel='INFO'

In [None]:
# Fine-tune the pre-trained segmentation model once per annotation class and
# store the resulting loss history of each run.
#
# NOTE: `Args` is used as a namespace (the class itself, not an instance), so
# the assignment to `args.focused_annotation` below rewrites the class attribute.
args = Args
collate_fn = collateGCN
list_anns = list(event_enc.keys())

# Create the output folder up front so the loop cannot abort on a missing
# directory after a complete training run.
os.makedirs("results", exist_ok=True)

for ann in list_anns:
    print(f"\n {ann}")
    args.focused_annotation = ann

    # Read data for specific annotation
    train_dataset = CALFData(split="train", args=args)
    validation_dataset = CALFData(split="validate", args=args)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)

    # Validation batches are only evaluated, so they are not shuffled.
    validate_loader = torch.utils.data.DataLoader(validation_dataset,
                batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)

    # Load pre-trained model and adjust it
    # NOTE(review): torch.load unpickles a complete model object here -- only
    # load checkpoints from a trusted source. On PyTorch >= 2.6 the default
    # `weights_only=True` rejects pickled modules; pass `weights_only=False`
    # there if this call fails.
    model = torch.load(args.sgementation_path)
    model.num_classes = 1

    # torch.load restores the model on the device it was saved from, whereas a
    # freshly built layer lives on the CPU; keep the replacement layer on the
    # same device as the rest of the model.
    model_device = next(model.parameters()).device
    # NOTE(review): in_channels=152 is hard-coded and must match the layer that
    # feeds conv_seg in the loaded model -- confirm against Model.py.
    model.conv_seg = nn.Conv2d(in_channels=152, out_channels=model.dim_capsule,
                               kernel_size=(model.kernel_seg_size,1)).to(model_device)

    criterion_segmentation = ContextAwareLoss(K=train_dataset.K_parameters)
    criterion_spotting = SpottingLoss(lambda_coord=args.lambda_coord, lambda_noobj=args.lambda_noobj)
    # NOTE(review): `args.freeze_model` is not consulted in this cell; every
    # parameter of the loaded model is handed to the optimizer.
    optimizer = torch.optim.Adam(model.parameters(), lr=args.LR,
                                betas=(0.9, 0.999), eps=1e-07,
                                weight_decay=0, amsgrad=False)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose=True, patience=args.patience)

    losses = trainer(train_loader, validate_loader,
                        model, optimizer, scheduler,
                        [criterion_segmentation, criterion_spotting],
                        [args.loss_weight_segmentation, args.loss_weight_detection],
                        model_name=args.model_name,
                        max_epochs=args.max_epochs, evaluation_frequency=args.evaluation_frequency,
                        save_dir=f"models/finetuned_{ann}.pth.tar")

    with open(f'results/finetuned_{ann}.pkl', 'wb') as file:
        pickle.dump(losses, file)

    # The optimizer (and through it the scheduler) still references the model
    # parameters, so drop them together with the model before the next run.
    del train_dataset, validation_dataset, train_loader, validate_loader
    del model, optimizer, scheduler