In [16]:
import os

import torch
from torchvision.datasets import CocoDetection
from torch.utils.data import DataLoader, random_split
# from torchvision.transforms import functional as F
from torch.nn import functional as F
import torchvision.transforms.functional as VF
# from pycocotools.coco import COCO
import torchvision.transforms.v2 as T
import torch.nn as nn
from torchvision.models import resnet18
from torchvision import transforms

import matplotlib.pyplot as plt
from itertools import cycle

from tqdm.notebook import tqdm

import random
import numpy as np
import torchvision.models as models

import wandb


import import_ipynb

import math

In [17]:
wandb.login()



True

In [18]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

if torch.mps.is_available():
    device = "mps"
print(device)

# reduce cpu contention
torch.set_num_threads(1)
NUM_WORKERS = 6  # adjust based on CPU cores

mps


In [19]:
NUM_CLASSES = 51
EPOCHS = 100
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4
BATCH_SIZE = 16

CROP_SIZE = (256, 256)
DATA_AUGMENTATION = False

In [20]:
def train_one_epoch(model, dataloader, optimizer, criterion, verbose_tqdm=False):
    model.train()
    total_loss = 0.0

    dl = tqdm(dataloader, desc="Training") if verbose_tqdm else dataloader    
    for imgs, keypoints in dl:
        imgs = imgs.to(device)
        keypoints = keypoints.to(device)
        
        optimizer.zero_grad()
        
        preds = model(imgs)
        loss = criterion(preds, keypoints)
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)

def evaluate(model, dataloader, criterion, verbose_tqdm=False):
    model.eval()
    total_loss = 0.0

    dl = tqdm(dataloader, desc="Evaluating") if verbose_tqdm else dataloader
    with torch.no_grad():
        for imgs, keypoints in dl:
            imgs = imgs.to(device)
            keypoints = keypoints.to(device)
            
            preds = model(imgs)
            loss = criterion(preds, keypoints)
            
            total_loss += loss.item()
    
    return total_loss / len(dataloader)

In [21]:
class FusionDataset(torch.utils.data.Dataset):
    def __init__(self, base_dataset, keypoint_model, device):
        self.base_dataset = base_dataset
        self.keypoint_model = keypoint_model.eval().to(device)
        self.device = device

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        frames, label = self.base_dataset[idx]  # frames: (16, 3, 224, 224)
        frames = frames.to(self.device)

        keypoints_seq = []
        with torch.no_grad():
            for frame in frames:
                keypoints = self.keypoint_model(frame.unsqueeze(0))  # (1, 3, 224, 224)
                keypoints_seq.append(keypoints.squeeze(0))  # (17, 2)

        keypoints_tensor = torch.stack(keypoints_seq)  # shape: (16, 17, 2)

        return (frames.cpu(), keypoints_tensor.cpu()), label


In [22]:


class ActionsFusionModel(nn.Module):
    def __init__(self, num_keypoints=17, num_actions=10):
        super().__init__()

        # === Wizualny tor (CNN) ===
        base_model = models.resnet18(pretrained=True)
        self.cnn_backbone = nn.Sequential(*list(base_model.children())[:-1])
        self.feature_dim_img = base_model.fc.in_features  # 512

        for param in self.cnn_backbone.parameters():
            param.requires_grad = False

        # === Tor dla keypointów === TODO: Tu trzeba jeszce ogarnąć jakie jest wejście - keypoint_dim
        self.keypoint_dim = num_keypoints * 2  # (x,y) dla każdego punktu

        self.keypoint_mlp = nn.Sequential(
            nn.Linear(self.keypoint_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 128),
            nn.ReLU()
        )

        # === Klasyfikator na podstawie fuzji ===
        self.classifier = nn.Sequential(
            nn.Linear(self.feature_dim_img + 128, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(512, 256),
            nn.ReLU(),

            nn.Linear(256, num_actions)
        )

    def forward(self, x_img, x_kp):
        # === Obrazy ===
        B, T, C, H, W = x_img.shape
        x_img = x_img.view(B * T, C, H, W)

        with torch.no_grad():
            feat_img = self.cnn_backbone(x_img)  # (B*T, 512, 1, 1)
        feat_img = feat_img.view(B, T, self.feature_dim_img)
        feat_img = feat_img.mean(dim=1)  # (B, 512)

        # === Keypointy ===
        B, T, N, _ = x_kp.shape
        x_kp = x_kp.view(B, T, -1)           # (B, T, N*2)
        feat_kp = self.keypoint_mlp(x_kp)    # (B, T, 128)
        feat_kp = feat_kp.mean(dim=1)        # (B, 128)

        # === Fuzja ===
        fused = torch.cat([feat_img, feat_kp], dim=1)  # (B, 640)

        out = self.classifier(fused)  # (B, num_actions)
        return out


In [23]:
import sys
sys.path.append("..")

import actions.data_loader as dl
# import keypoints.keypoints_boundingbox_approach as kp



keypoint_model = None
# keypoint_model = kp.KeypointCropModel().to(device)
# keypoint_model.load_state_dict(torch.load("../../models/bb_23loss_keypoint_crop_model.pth", map_location=device))
# keypoint_model.to(device)
# keypoint_model.eval()

In [24]:
def extract_keypoints_batch(model, imgs):
    """
    imgs: (B, T, 3, H, W)
    Zwraca: (B, T, N, 2) – keypointy
    """
    B, T, C, H, W = imgs.shape
    imgs = imgs.to(device)

    keypoints_list = []

    for b in range(B):
        sample_keypoints = []
        for t in range(T):
            frame = imgs[b, t].unsqueeze(0)  # (1, 3, H, W)
            kp = model(frame)  # np. (1, N, 2)
            sample_keypoints.append(kp.squeeze(0).cpu())  # (N, 2)
        sample_keypoints = torch.stack(sample_keypoints, dim=0)  # (T, N, 2)
        keypoints_list.append(sample_keypoints)

    keypoints_tensor = torch.stack(keypoints_list, dim=0)  # (B, T, N, 2)
    return keypoints_tensor

In [25]:
class FusionWrapperDataset(torch.utils.data.Dataset):
    def __init__(self, base_dataset, keypoint_model):
        self.base = base_dataset
        self.keypoint_model = keypoint_model

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        imgs, label = self.base[idx]  # imgs: (T, 3, H, W)
        imgs = imgs.unsqueeze(0)  # (1, T, 3, H, W)

        keypoints = extract_keypoints_batch(self.keypoint_model, imgs)  # (1, T, N, 2)
        keypoints = keypoints.squeeze(0)  # (T, N, 2)

        return imgs.squeeze(0), keypoints, label

In [26]:
fusion_train = FusionWrapperDataset(dl.dataset_train, keypoint_model)
fusion_val   = FusionWrapperDataset(dl.dataset_valid, keypoint_model)
fusion_test  = FusionWrapperDataset(dl.dataset_test, keypoint_model)

train_loader_fused = DataLoader(fusion_train, batch_size=BATCH_SIZE, shuffle=True)
val_loader_fused   = DataLoader(fusion_val, batch_size=BATCH_SIZE, shuffle=False)
test_loader_fused  = DataLoader(fusion_test, batch_size=BATCH_SIZE, shuffle=False)

# Model
model = ActionsFusionModel(num_keypoints=17, num_classes=51).to(device)

TypeError: ActionsFusionModel.__init__() got an unexpected keyword argument 'num_classes'

In [None]:
criterion = nn.CrossEntropyLoss();
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

In [None]:
wandb_config = {
    "epochs": EPOCHS,
    "learning_rate": LEARNING_RATE,
    "weight_decay": WEIGHT_DECAY,
    "batch_size": BATCH_SIZE,
    "train_size": len(fusion_train),
    "val_size": len(fusion_val),
    "test_size": len(fusion_test),
    "model": "ActionsBaselineModel",
    "criterion": "Cross entropy loss",
    "optimizer": "Adam",
    "crop_size": CROP_SIZE,
    "device": device,
    "data_augmentation": DATA_AUGMENTATION
}

wandb.init(
    entity="fejowo5522-",
    project="NN_Project",
    config=wandb_config,
    group="ActionsBaseline"
)

In [None]:
verbose_tqdm = True
early_stopping = True
patience = 20
best_val_loss = float('inf')
epochs_no_improve = 0

train_losses = []
val_losses = []

In [None]:
from tqdm.auto import tqdm

for epoch in tqdm(range(EPOCHS)):
    # print(f"Epoch {epoch+1}/{EPOCHS}")
    
    # Train
    train_loss = train_one_epoch(model, dl.train_loader, optimizer, criterion, verbose_tqdm=verbose_tqdm)
    train_losses.append(train_loss)
    
    # Validate
    val_loss = evaluate(model, dl.val_loader, criterion, verbose_tqdm=verbose_tqdm)
    val_losses.append(val_loss)
    
    # print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    
    # Log to wandb
    wandb.log({
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'val_loss': val_loss
    })
    
    # Early stopping
    if early_stopping:
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), "best_model.pth")
        else:
            epochs_no_improve += 1
        
        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break