In [1]:
%load_ext autoreload
%autoreload 2

In [42]:
from pathlib import Path
from functools import partial
from typing import Callable, List
import warnings
import pickle
import sys

import albumentations as A
import pandas as pd
import numpy as np
import cv2
import torch
import torchvision
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, classification_report
import onnxruntime as ort

sys.path.append('../src/')
from avg_video import avg_video
from dataset import AvgVideoDataset
from model import construct_mn3_model
from utils import save2onnx

warnings.filterwarnings("ignore")

In [11]:
DATA_DIR = Path('../data/')
SRC_DIR = Path('../src/')
CLASSES = ["bridge_down", "bridge_up", "no_action", "train_in_out"]
SEED = 42
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [12]:
with open(SRC_DIR.joinpath('id2label.pkl'), 'rb') as fp:
    id2label = pickle.load(fp)
    label2id = {v: k for k, v in id2label.items()}

In [21]:
np.random.seed(SEED)
torch.manual_seed(SEED)
pl.seed_everything(SEED);

Global seed set to 42


# Data

In [13]:
train_clips_avg = sum([list(DATA_DIR.joinpath("train_avg", c).glob("*.jpg")) for c in CLASSES], [])
train_clips_avg = pd.DataFrame([clip.parts[-2:] for clip in train_clips_avg], columns=["label", "fname"])

train_clips_avg

Unnamed: 0,label,fname
0,bridge_down,d6be86768fa129b8.jpg
1,bridge_down,1b98537880b065ff.jpg
2,bridge_down,462ce800d75fb407.jpg
3,bridge_down,92848c13f18c8442.jpg
4,bridge_down,928cc9eaf2eb7590.jpg
...,...,...
491,train_in_out,7b11b87dd010c638.jpg
492,train_in_out,e0351f88ed52db65.jpg
493,train_in_out,32afa0ba3020e09f.jpg
494,train_in_out,5e53e235164e5386.jpg


In [14]:
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize((180, 180)),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [15]:
train_data, val_data = train_test_split(train_clips_avg,
                                        train_size=0.8,
                                        stratify=train_clips_avg['label'],
                                        shuffle=True,
                                        random_state=SEED)

train_labels = train_data['label'].map(id2label)
val_labels = val_data['label'].map(label2id)

In [33]:
batch_size = 16

train_dataset = AvgVideoDataset(train_data, transforms=transforms)
val_dataset = AvgVideoDataset(val_data, transforms=transforms)


train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

In [34]:
freeze_pretrained = True

# Define model
model = construct_mn3_model(freeze_pretrained)
model.to(DEVICE)

MobileNetV3(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), 

In [35]:
lr = 1e-4
weight_decay = 0.01

# Define loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=lr, 
                              weight_decay=weight_decay)

In [36]:
# Training loop
num_epochs = 5
for epoch in tqdm(range(num_epochs)):
    model.train()  # Set the model to training mode
    train_loss = 0.0

    for images, labels in tqdm(train_dataloader, desc='Training'):
        optimizer.zero_grad()

        images, labels = images.to(DEVICE), labels.to(DEVICE)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Update statistics
        train_loss += loss.item()
        # predicted = outputs.argmax(1)

    train_loss /= len(train_dataloader)

    # Validation loop
    model.eval()  # Set the model to evaluation mode
    val_loss = 0.0
    val_preds = []
    with torch.inference_mode():
        for images, labels in tqdm(val_dataloader, desc='Validating'):
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Update statistics
            val_loss += loss.item()
            predicted = outputs.argmax(1)
            val_preds.extend(list(predicted.cpu().numpy()))

        val_f1 = f1_score(val_labels, val_preds, average='macro')

    val_loss /= len(val_dataloader)
    
    print(f'Epoch: {epoch}, Val F1: {val_f1}, Train loss: {train_loss}, Val loss: {val_loss}')


  0%|          | 0/5 [00:00<?, ?it/s]

Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/7 [00:00<?, ?it/s]

Epoch: 0, Val F1: 0.19135802469135801, Train loss: 0.9660446810722351, Val loss: 0.7973831636565072


Training:   0%|          | 0/25 [00:00<?, ?it/s]

Validating:   0%|          | 0/7 [00:00<?, ?it/s]

Epoch: 1, Val F1: 0.5706372549019608, Train loss: 0.5595075488090515, Val loss: 0.5292345583438873


Training:   0%|          | 0/25 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [26]:
from torchmetrics import F1Score
from typing import Tuple, NoReturn, Union, Dict

In [39]:
class AvgVideoDataModule(pl.LightningDataModule):
    def __init__(self,
                 train_data: pd.DataFrame,
                 val_data: pd.DataFrame,
                 batch_size: int = 16,
                 resize: Tuple[int, int] = (180, 180),
                 num_workers: int = 0):
        super().__init__()
        self.train_data = train_data
        self.val_data = val_data
        self.batch_size = batch_size
        self.resize = resize
        self.num_workers = num_workers

    def setup(self, stage: Union[None, str]) -> NoReturn:
        if stage == 'fit' or stage is None:
            transforms = torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Resize(self.resize),
                torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                 std=[0.229, 0.224, 0.225])
            ])
            self.train_dataset = AvgVideoDataset(self.train_data, transforms)
            self.val_dataset = AvgVideoDataset(self.val_data, transforms)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          shuffle=True, 
                          num_workers=self.num_workers)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          shuffle=False, 
                          num_workers=self.num_workers)


In [45]:
class MobileNetV3(pl.LightningModule):
    def __init__(self, 
                 freeze_pretrained: bool = True,
                 lr: float = 1e-4,
                 weight_decay: float = 0.01):
        super().__init__()
        self.model = construct_mn3_model(freeze_pretrained)
        self.lr = lr
        self.weight_decay = weight_decay

    def setup(self, stage: Union[None, str]) -> NoReturn:
        if stage == 'fit':
            # Metrics
            self.f1_train = F1Score(task='multiclass', threshold=0.5, num_classes=4, average='macro')
            self.f1_val = F1Score(task='multiclass', threshold=0.5, num_classes=4, average='macro')
            # Loss fn
            self.loss_fn = torch.nn.CrossEntropyLoss()


    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        return self.model(batch)
    
    def training_step(self, 
                      batch: Tuple[torch.Tensor, int], 
                      batch_idx: int):
        images, labels = batch
        logits = self(images)
        loss = self.loss_fn(logits, labels)
        self.log('loss_train', loss,
                 on_epoch=True, prog_bar=True)
        return {'loss': loss, 
                'logits': logits, 
                'labels': labels}
    
    def training_step_end(self, outputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        self.f1_train.update(outputs['logits'], outputs['labels'])
        self.log('f1_train', self.f1_train,
                 on_epoch=True, prog_bar=True)
        loss = torch.mean(outputs['loss'])
        return loss
    
    def validation_step(self, 
                        batch: Tuple[torch.Tensor, int], 
                        batch_idx: int):
        images, labels = batch
        logits = self(images)
        loss = self.loss_fn(logits, labels)
        self.log('loss_val', loss,
                 on_epoch=True, prog_bar=True)
        return {'loss': loss, 
                'logits': logits, 
                'labels': labels}

    def validation_step_end(self, outputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        self.f1_val.update(outputs['logits'], outputs['labels'])
        self.log('f1_val', self.f1_val,
                 on_epoch=True, prog_bar=True)
        loss = torch.mean(outputs['loss'])
        return loss

    def configure_optimizers(self) -> torch.optim.Optimizer:
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.lr, 
            weight_decay=self.weight_decay
        )
        return optimizer

In [46]:
datamodule = AvgVideoDataModule(train_data, val_data)
model = MobileNetV3()

In [47]:
early_stopping = EarlyStopping(monitor='f1_val',
                               mode='min',
                               patience=3)

checkpoint = ModelCheckpoint(monitor='f1_val',
                             dirpath=)