In [46]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms.v2 as transforms
from torchvision.models.video import r2plus1d_18, R2Plus1D_18_Weights
from os import path

from train_model import TrainModel
from video_tensor_dataset import VideoTensorDataset

BATCH_SIZE = 8
NUM_EPOCHS = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {DEVICE}")

Using device cuda


In [47]:
MY_NAME_IS = 'Nitzan'

DATA_FOLDER_DICT = {
    'Victor': ['E:\DeepFakeDetection\dfdc_train_all','E:\DeepFakeDetection\smalldata'],
    'Nitzan': ['D:\dfdc','D:\dfdc_small4'],
    'Netanel':['F:\input','F:\input']
  }

BIG_DATA_FOLDER, SMALL_DATA_FOLDER = DATA_FOLDER_DICT[MY_NAME_IS]

TRAIN_PARTS = list(range(8))
VALIDATION_PARTS = [8, 9]

In [48]:
model = r2plus1d_18(weights=R2Plus1D_18_Weights.DEFAULT)

In [49]:
# Replace head

model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 1, device=DEVICE),
)

# Freeze all layers except the new head
for key, params in model.named_parameters():
    if 'fc' not in key:
        params.requires_grad = False

model.to(device=DEVICE)
print(model) # Perhaps we should also print torchinfo.summary here

VideoResNet(
  (stem): R2Plus1dStem(
    (0): Conv3d(3, 45, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False)
    (1): BatchNorm3d(45, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv3d(45, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
    (4): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Sequential(
        (0): Conv2Plus1D(
          (0): Conv3d(64, 144, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
          (1): BatchNorm3d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv3d(144, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
        )
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tru

In [50]:
# Taken from https://github.com/pytorch/vision/tree/main/references/video_classification
mean = torch.tensor([0.43216, 0.394666, 0.37645], device=DEVICE)[None, None, None, :]
std = torch.tensor([0.22803, 0.22145, 0.216989], device=DEVICE)[None, None, None, :]

# Initialize train and validation datasets
# TODO videodataset root paths and transforms - this depends on Victor's offline preprocessing 
train_roots = [path.join(SMALL_DATA_FOLDER, str(i)) for i in TRAIN_PARTS]
train_transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32),
    transforms.Lambda(lambda video: torch.narrow(video, 0, 0, 95)),
    transforms.Lambda(lambda video: (video - mean) / std),
    # video is in shape (T, H, W, C), we need to permute it to (C, T, H, W)
    transforms.Lambda(lambda video: video.permute(3, 0, 1, 2)),
])

train_ds = VideoTensorDataset(
  original_data_path=BIG_DATA_FOLDER,
  device=DEVICE, 
  tensor_data_paths=train_roots, 
  transform=train_transform
)

validation_roots = [path.join(SMALL_DATA_FOLDER, str(i)) for i in VALIDATION_PARTS]
validation_transform = train_transform

validation_ds = VideoTensorDataset(
  original_data_path=BIG_DATA_FOLDER, 
  device=DEVICE,
  tensor_data_paths=validation_roots, 
  transform=validation_transform
)

# Initialize dataloaders

train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
validation_dl = DataLoader(validation_ds, batch_size=BATCH_SIZE, shuffle=False)

In [51]:
batch, labels = next(iter(train_dl))
print(batch.shape, labels.shape)

RuntimeError: The size of tensor a (112) must match the size of tensor b (3) at non-singleton dimension 3

In [32]:
from torchvision.ops import sigmoid_focal_loss
from torchmetrics.classification import BinaryMatthewsCorrCoef

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
def loss_fn(y_pred, y_true):
    return sigmoid_focal_loss(torch.squeeze(y_pred), y_true, gamma=2, alpha=0.161, reduction='mean')

bmcc = BinaryMatthewsCorrCoef().to(device=DEVICE)
def score_fn(y_pred, y_true): 
    # accuracy
    return (y_pred.round() == y_true).float().mean()

    # return bmcc(torch.squeeze(y_pred), y_true)

In [33]:
# Train the model!

model, train_loss, train_score, validation_loss, validation_score, _ = TrainModel(
  model, 
  train_dl, 
  validation_dl, 
  optimizer, 
  NUM_EPOCHS, 
  loss_fn, 
  score_fn
)

Epoch    1 / 10 | Train Loss:  0.050 | Val Loss:  0.064 | Train Score:  0.346 | Val Score:  0.178 | Epoch Time: 1683.84 | <-- Checkpoint! |
Epoch    2 / 10 | Train Loss:  0.048 | Val Loss:  0.071 | Train Score:  0.389 | Val Score:  0.147 | Epoch Time: 1644.67 |
Epoch    3 / 10 | Train Loss:  0.048 | Val Loss:  0.068 | Train Score:  0.389 | Val Score:  0.168 | Epoch Time: 1621.22 |
Epoch    4 / 10 | Train Loss:  0.048 | Val Loss:  0.062 | Train Score:  0.390 | Val Score:  0.291 | Epoch Time: 1647.36 | <-- Checkpoint! |
Epoch    5 / 10 | Train Loss:  0.047 | Val Loss:  0.060 | Train Score:  0.392 | Val Score:  0.268 | Epoch Time: 1663.93 |
Epoch    6 / 10 | Train Loss:  0.047 | Val Loss:  0.066 | Train Score:  0.394 | Val Score:  0.463 | Epoch Time: 1917.84 | <-- Checkpoint! |
Epoch    7 / 10 | Train Loss:  0.048 | Val Loss:  0.069 | Train Score:  0.407 | Val Score:  0.177 | Epoch Time: 1761.96 |
Epoch    8 / 10 | Train Loss:  0.047 | Val Loss:  0.059 | Train Score:  0.405 | Val Score:  

KeyboardInterrupt: 

In [None]:
# TODO plot the results?