In [None]:
import torch
import wandb
from tqdm import tqdm

from data import FrameDataset, VideoDataset, get_dataloader, get_transforms
from model import FrameClassifier
from train import MetricsLogger, train_epoch, val_epoch

In [None]:
# Run configuration. The whole dict is attached to the W&B run through
# `wandb.init(config=...)`, so tunable settings placed here are recorded
# alongside the logged metrics.
config = {
    'epochs': 30,
    # Seed for torch's global RNG (same value the datasets receive as
    # `random_seed`). Stored in the config so the run can be reproduced.
    'seed': 42,
}

# Seed torch before any model or dataloader exists, so that weight
# initialisation and torch-driven shuffling are repeatable after
# Restart Kernel -> Run All. NOTE: this does not by itself guarantee
# bit-exact results on CUDA; some GPU kernels are non-deterministic.
torch.manual_seed(config['seed'])

wandb.init(project='action-recognition', config=config)

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Per-channel normalisation statistics (these are the widely used ImageNet
# RGB mean / std values).
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Preprocessing shared by every dataset below: resize to 256, crop to 224,
# then normalise with the statistics above.
# NOTE(review): the exact resize/crop semantics (shorter side? centre vs.
# random crop?) are defined in `data.get_transforms` -- confirm there.
transforms = get_transforms(
    resize=256, crop=224, mean=channel_mean, std=channel_std,
)

# Every dataset reads from the same root directory with the same seed and
# preprocessing; only the dataset class and the metadata CSV differ.
shared_dataset_args = dict(
    root_dir='../data',
    random_seed=42,
    transforms=transforms,
)

# Training split served through FrameDataset (presumably one sample per
# frame -- confirm in `data.FrameDataset`). This is what the model is
# optimised on.
train_set = FrameDataset(metadata_filename='train.csv', **shared_dataset_args)

# The same training CSV served through VideoDataset, used to score the
# training split the same way the validation split is scored.
train_val_set = VideoDataset(metadata_filename='train.csv', **shared_dataset_args)

# Held-out validation split, served through VideoDataset.
val_set = VideoDataset(metadata_filename='val.csv', **shared_dataset_args)

# Worker process count used by all three loaders.
loader_workers = 8

# Loader that feeds the optimisation step: batches of 32 in 'train' mode.
train_loader = get_dataloader(
    dataset=train_set, batch_size=32, mode='train', num_workers=loader_workers,
)

# Both evaluation loaders use 'val' mode with one item per batch.
# NOTE(review): batch_size=1 is presumably required because a VideoDataset
# item holds a variable number of frames -- confirm against the dataset.
eval_loader_args = dict(batch_size=1, mode='val', num_workers=loader_workers)

train_val_loader = get_dataloader(dataset=train_val_set, **eval_loader_args)
val_loader = get_dataloader(dataset=val_set, **eval_loader_args)

In [None]:
# Size the classifier head from the training set's label space, then move
# the network to the selected device before the optimizer captures its
# parameters.
num_classes = train_set.num_classes
model = FrameClassifier(num_classes=num_classes)
model = model.to(device)

# BCEWithLogitsLoss applies the sigmoid internally, so it expects raw
# (un-squashed) scores. NOTE(review): this assumes FrameClassifier returns
# logits and that the targets are per-class 0/1 floats -- confirm in
# `model.py` / `train.py`.
criterion = torch.nn.BCEWithLogitsLoss()

# AdamW with the library's default hyperparameters.
optimizer = torch.optim.AdamW(model.parameters())

# Collects per-epoch metrics until they are pushed to W&B.
logger = MetricsLogger()

In [None]:
# Evaluation passes run after each training epoch, in this order. Each pair
# is (loader to score, tag the metrics are logged under).
eval_passes = (
    (train_val_loader, 'Train (Video)'),
    (val_loader, 'Val (Video)'),
)

for epoch in tqdm(range(config['epochs'])):
    # One optimisation pass over the frame-level training loader. `loss` is
    # returned by `train_epoch` but is not consumed anywhere in this cell.
    y_true, y_pred, loss = train_epoch(
        dataloader=train_loader,
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
    )
    logger.log_epoch(y_true, y_pred, 'Train (Frame)')

    # Score the model on each evaluation loader. No criterion or optimizer
    # is passed to `val_epoch`, so these passes only produce predictions.
    for eval_loader, tag in eval_passes:
        y_true, y_pred = val_epoch(
            dataloader=eval_loader,
            model=model,
            device=device,
        )
        logger.log_epoch(y_true, y_pred, tag)

    # Push everything recorded for this epoch to W&B.
    logger.log_wandb()

In [None]:
# Mark the W&B run opened by `wandb.init(...)` earlier in the notebook as
# finished. Kept in its own cell so it runs only once training has completed.
wandb.finish()