In [61]:
import os

import torch
from torchvision.datasets import CocoDetection
from torch.utils.data import DataLoader, random_split
# from torchvision.transforms import functional as F
from torch.nn import functional as F
import torchvision.transforms.functional as VF
# from pycocotools.coco import COCO
import torchvision.transforms.v2 as T
import torch.nn as nn
from torchvision.models import resnet18
from torchvision import transforms

import matplotlib.pyplot as plt
from itertools import cycle

from tqdm.notebook import tqdm

import random
import numpy as np
import torchvision.models as models

import wandb


import import_ipynb

import math

In [62]:
# Log this session in to Weights & Biases; the run itself is created later by wandb.init().
wandb.login()



True

In [63]:
# Select the compute device. Priority is MPS, then CUDA, then CPU.
# torch.backends.mps.is_available() is the documented MPS probe; the
# torch.mps.is_available() alias is missing from older PyTorch releases and
# raises AttributeError there.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# Reduce CPU contention by limiting intra-op parallelism to a single thread.
torch.set_num_threads(1)
# Intended DataLoader worker count; adjust based on CPU cores.
# NOTE(review): the DataLoaders built in the visible cells do not pass this value.
NUM_WORKERS = 6

mps


In [64]:
# Number of action classes; passed to the model as the width of its output layer.
NUM_CLASSES = 3

In [65]:
def train_one_epoch(model, dataloader, optimizer, criterion, verbose_tqdm=False):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: Module to optimise; it is switched to training mode.
        dataloader: Iterable yielding ``(inputs, targets)`` pairs. It does not
            need to implement ``__len__``; batches are counted as they are consumed.
        optimizer: Optimizer whose ``zero_grad``/``step`` run once per batch.
        criterion: Callable ``criterion(predictions, targets)`` returning a
            scalar loss tensor.
        verbose_tqdm: When true, wrap the loader in a tqdm progress bar.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: If the dataloader yields no batches (this used to surface
            as a bare ZeroDivisionError).

    Note:
        Batches are moved to the notebook-level ``device``.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    batches = tqdm(dataloader, desc="Training") if verbose_tqdm else dataloader
    for inputs, targets in batches:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()

        preds = model(inputs)
        loss = criterion(preds, targets)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_one_epoch received a dataloader that yielded no batches")

    return total_loss / num_batches

def evaluate(model, dataloader, criterion, verbose_tqdm=False):
    """Compute the mean batch loss of ``model`` over ``dataloader`` without gradients.

    Args:
        model: Module to evaluate; it is switched to evaluation mode.
        dataloader: Iterable yielding ``(inputs, targets)`` pairs. It does not
            need to implement ``__len__``; batches are counted as they are consumed.
        criterion: Callable ``criterion(predictions, targets)`` returning a
            scalar loss tensor.
        verbose_tqdm: When true, wrap the loader in a tqdm progress bar.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: If the dataloader yields no batches (this used to surface
            as a bare ZeroDivisionError).

    Note:
        Batches are moved to the notebook-level ``device``.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    batches = tqdm(dataloader, desc="Evaluating") if verbose_tqdm else dataloader
    with torch.no_grad():
        for inputs, targets in batches:
            inputs = inputs.to(device)
            targets = targets.to(device)

            preds = model(inputs)
            loss = criterion(preds, targets)

            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("evaluate received a dataloader that yielded no batches")

    return total_loss / num_batches

In [66]:
class ActionsBaselineModel(nn.Module):
    """Clip-level action classifier built on a frozen, pretrained ResNet-18.

    Every frame of a clip is encoded by the ResNet-18 backbone (its final fully
    connected layer removed), the frame features are averaged over the time
    axis, and an MLP head maps the pooled feature to ``num_actions`` logits.

    Args:
        num_actions: Number of output classes (width of the final linear layer).

    Note:
        The head contains ``BatchNorm1d`` layers, which raise ``ValueError`` in
        training mode when a batch holds a single clip; feed at least two clips
        per training batch.
    """

    def __init__(self, num_actions=10):
        super().__init__()
        # Pretrained ResNet-18 without its last FC layer. `weights=` replaces the
        # deprecated `pretrained=True` flag and loads the default ImageNet weights.
        base_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.cnn_backbone = nn.Sequential(*list(base_model.children())[:-1])  # drops fc
        self.feature_dim = base_model.fc.in_features  # 512 for resnet18

        # Freeze the backbone: no gradients, and eval mode so its BatchNorm
        # running statistics stay at their pretrained values.
        for param in self.cnn_backbone.parameters():
            param.requires_grad = False
        self.cnn_backbone.eval()

        self.classifier = nn.Sequential(
            nn.Linear(self.feature_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(512, 256),
            nn.ReLU(),

            nn.Linear(256, num_actions)
        )

    def train(self, mode=True):
        """Set the training mode of the head while keeping the backbone in eval mode.

        ``torch.no_grad()`` does not stop BatchNorm layers from updating their
        running statistics, so without this override every ``model.train()``
        call would let the frozen backbone drift during training.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            The module itself, as ``nn.Module.train`` does.
        """
        super().train(mode)
        self.cnn_backbone.eval()
        return self

    def forward(self, x):
        """Classify a batch of clips.

        Args:
            x: Tensor of shape ``(B, T, C, H, W)``: B clips of T frames each.

        Returns:
            Tensor of shape ``(B, num_actions)`` holding unnormalised class logits.

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected input of shape (B, T, C, H, W), got {tuple(x.shape)}"
            )
        B, T, C, H, W = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. permuted clips, also work.
        x = x.reshape(B * T, C, H, W)

        # The backbone is frozen, so no autograd graph is needed for it.
        with torch.no_grad():
            features = self.cnn_backbone(x)  # (B*T, 512, 1, 1)
        features = features.reshape(B, T, self.feature_dim)  # (B, T, 512)

        # Temporal average pooling over the frames of each clip.
        features = features.mean(dim=1)  # (B, 512)

        out = self.classifier(features)  # (B, num_actions)
        return out

In [67]:
# Build the classifier with one output per action class, then place it on the selected device.
model = ActionsBaselineModel(num_actions=NUM_CLASSES)
model = model.to(device)



In [68]:
# --- Optimisation hyperparameters ---
EPOCHS = 50
BATCH_SIZE = 4
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4

# --- Input settings (in the visible cells these are only recorded in the run config) ---
CROP_SIZE = (256, 256)
DATA_AUGMENTATION = False

# Cross-entropy over the logits returned by the model.
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters with the learning rate and weight decay above.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

In [69]:
import sys

# Make the parent directory importable so the project-local `actions` package resolves.
# Guarded so that re-running this cell does not keep appending duplicates.
if ".." not in sys.path:
    sys.path.append("..")

import actions.data_loader_simple as dl

# Training drops the final short batch: the model head uses BatchNorm1d, which
# cannot normalise a single-sample batch in training mode.
train_loader = DataLoader(dl.dataset_train, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# Evaluation runs in eval mode, where any batch size is fine, so keep every
# sample instead of silently discarding up to BATCH_SIZE - 1 of them.
val_loader = DataLoader(dl.dataset_valid, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
test_loader = DataLoader(dl.dataset_test, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)


In [70]:
# Everything recorded with the run: hyperparameters, dataset sizes and setup labels.
wandb_config = dict(
    epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    batch_size=BATCH_SIZE,
    train_size=len(dl.dataset_train),
    val_size=len(dl.dataset_valid),
    test_size=len(dl.dataset_test),
    model="ActionsBaselineModel",
    criterion="Cross entropy loss",
    optimizer="Adam",
    crop_size=CROP_SIZE,
    device=device,
    data_augmentation=DATA_AUGMENTATION,
)

# Start the Weights & Biases run, grouped with the other baseline runs.
wandb.init(
    config=wandb_config,
    entity="fejowo5522-",
    project="NN_Project",
    group="ActionsBaseline",
)

Python(15712) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


In [71]:
# Show a per-batch progress bar inside the train/eval helpers.
verbose_tqdm = True

# Early-stopping switch and its tolerance, in epochs without improvement.
early_stopping = True
patience = 20

# Running state updated by the training loop.
epochs_no_improve = 0
best_val_loss = float("inf")

# Per-epoch loss history.
train_losses, val_losses = [], []


In [72]:
from tqdm.auto import tqdm

# Main training loop: one training pass and one validation pass per epoch,
# with losses logged to Weights & Biases.
for epoch in tqdm(range(EPOCHS)):
    # Train
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, verbose_tqdm=verbose_tqdm)
    train_losses.append(train_loss)

    # Validate
    val_loss = evaluate(model, val_loader, criterion, verbose_tqdm=verbose_tqdm)
    val_losses.append(val_loss)

    # Log to wandb
    wandb.log({
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'val_loss': val_loss
    })

    # Track the best validation loss and checkpoint the weights that achieved it.
    # This runs whether or not early stopping is enabled, so turning early
    # stopping off no longer disables saving the best model.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), "best_model.pth")
    else:
        epochs_no_improve += 1

    # Early stopping: give up after `patience` consecutive epochs without improvement.
    if early_stopping and epochs_no_improve >= patience:
        print(f"Early stopping at epoch {epoch+1}")
        break


  0%|          | 0/50 [00:00<?, ?it/s]

Training:   0%|          | 0/53 [00:00<?, ?it/s]

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 1024])