In [None]:
from torch.utils.data import DataLoader
from Dataset.NSRVideoDataset import NSRVideoDataset
from easydict import EasyDict
import torch
from torchvision.transforms import transforms
from torch.utils.tensorboard import SummaryWriter
import time
import os
import torchsummary
from tqdm import tqdm
import pandas as pd
import numpy as np
from Data_management.EarlyStopping import EarlyStopping

In [None]:
# Experiment configuration: dataset location, split settings, and training
# hyper-parameters. `desc` is filled in below and becomes part of the run name.
args = EasyDict(dict(
    dataset_path=r"D:\Video-Dataset\2022-NSR",  # root directory path
    split=(0.8, 0.1),  # original note: train/validation, validation/train (interpreted by NSRVideoDataset)
    dataset_type=("Train", "Validation", "Test"),
    batch_size=16,
    epochs=300,
    learning_rate=1e-3,
    model_name="Densenet201",
    desc='',
    is_transform=True,
    feature='frame_diff_hist',  # diag, hist, frame_diff_hist
    use_frame_df=True,  # True, False
))

# Run description per feature: (without frame difference, with frame difference).
_LOG_DESC = {
    "hist": ('histogram_log', 'histogram_difference_log'),
    "diag": ('diagnal_log', 'diagnal_difference_log'),
}

if args.feature in _LOG_DESC:
    plain_desc, diff_desc = _LOG_DESC[args.feature]
    args.desc = diff_desc if args.use_frame_df else plain_desc
elif args.feature == "frame_diff_hist":
    # This feature has a single description regardless of `use_frame_df`.
    args.desc = "frame_diff_hist"

In [None]:
def _make_split(split_name, shuffle):
    """Build the dataset for one split together with its DataLoader.

    Args:
        split_name: Entry of ``args.dataset_type`` selecting the split.
        shuffle: Whether the DataLoader shuffles samples.

    Returns:
        A ``(dataset, dataloader)`` pair.
    """
    dataset = NSRVideoDataset(
        args.dataset_path,
        args.split,
        split_name,
        args.feature,
        args.use_frame_df,
        args.is_transform,
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        pin_memory=True,
        num_workers=0,
        drop_last=False,
    )
    return dataset, loader


# Only the training split is shuffled; validation and test keep dataset order.
train_dataset, train_dataloader = _make_split(args.dataset_type[0], True)
validation_dataset, validation_dataloader = _make_split(args.dataset_type[1], False)
test_dataset, test_dataloader = _make_split(args.dataset_type[2], False)

In [None]:
from Models.Densenet201 import Densenet201

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = Densenet201()
model = model.to(device, non_blocking=True)
# Optional architecture summary:
# torchsummary.summary(model, (3, 224, 224))

# Binary cross-entropy computed on raw logits, optimised with Adam.
criterion = torch.nn.BCEWithLogitsLoss()
criterion = criterion.to(device, non_blocking=True)
optimizer = torch.optim.Adam(params=model.parameters(), lr=args.learning_rate)

In [None]:
# Train the model with per-epoch validation, best-checkpoint saving, early
# stopping, TensorBoard logging, and a CSV dump of the per-epoch metrics.
EXPERIMENT_DIR = f"./runs/{time.strftime('%Y-%m-%d-%H%M%S')}-{args.model_name}-{args.desc}"
os.makedirs(EXPERIMENT_DIR, exist_ok=True)

checkpoint_path = EXPERIMENT_DIR + "/models"
os.makedirs(checkpoint_path, exist_ok=True)

early_stopping = EarlyStopping(patience = 25, verbose = True)
tensorboard_writer = SummaryWriter(log_dir=EXPERIMENT_DIR)

train_losses, train_accuracys = [], []
val_losses, val_accuracys = [], []
min_validation_loss, save_idx = 1000000.0, 0

try:
    for epoch in tqdm(range(args.epochs), total=args.epochs, desc="Epoch progress"):
        train_avg_loss, train_accuracy = 0.0, 0.0
        validation_avg_loss, validation_accuracy = 0.0, 0.0

        total_train_batch = len(train_dataloader)
        total_validation_batch = len(validation_dataloader)

        # ---- training pass ----
        model.train()

        for datas, labels in tqdm(train_dataloader, total=total_train_batch, desc="train progress"):
            datas, labels = datas.float().to(device, non_blocking=True), labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            predict = model(datas)
            loss = criterion(predict, labels)
            loss.backward()
            optimizer.step()

            # The criterion returns the batch mean, so weight it by the batch size
            # to get a per-sample average after dividing by the dataset size below.
            train_avg_loss += loss.item() * datas.size(0)
            # The criterion is BCEWithLogitsLoss, so `predict` is treated as raw
            # logits; convert to probabilities before applying the 0.5 threshold.
            # NOTE(review): assumes Densenet201 has no final sigmoid — confirm.
            predicted_classes = torch.sigmoid(predict.detach()) > 0.5
            train_accuracy += (predicted_classes == labels).float().sum().item()

        train_avg_loss /= len(train_dataset.data_paths)
        train_accuracy /= len(train_dataset.data_paths)

        train_losses.append(train_avg_loss)
        train_accuracys.append(train_accuracy)
        print("Epoch: ", "%d" % (epoch + 1), "train_loss: ", "{:.9f}".format(train_avg_loss), "train_accuracy: ", train_accuracy)

        # ---- validation pass ----
        with torch.no_grad():
            model.eval()

            for datas, labels in tqdm(validation_dataloader, total=total_validation_batch, desc="validation progress"):
                datas, labels = datas.float().to(device, non_blocking=True), labels.to(device, non_blocking=True)

                predict = model(datas)

                loss = criterion(predict, labels)
                validation_avg_loss += loss.item() * datas.size(0)
                # Same logit -> probability conversion as in the training pass.
                predicted_classes = torch.sigmoid(predict) > 0.5
                validation_accuracy += (predicted_classes == labels).float().sum().item()

            validation_avg_loss /= len(validation_dataset.data_paths)
            validation_accuracy /= len(validation_dataset.data_paths)

            val_losses.append(validation_avg_loss)
            val_accuracys.append(validation_accuracy)
            print("Epoch: ", "%d" % (epoch + 1), "validation_loss: ", "{:.9f}".format(validation_avg_loss), "validation_accuracy: ", validation_accuracy)

            # Keep only the checkpoint with the lowest validation loss so far.
            if validation_avg_loss < min_validation_loss:
                checkpoint = {
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(checkpoint, checkpoint_path + "/checkpoint.pth.tar")

                print(f"{min_validation_loss} -> {validation_avg_loss} decreased validation loss -> saved model-{epoch + 1}")
                min_validation_loss = validation_avg_loss
                save_idx = epoch + 1

            early_stopping(validation_avg_loss, model)

        tensorboard_writer.add_scalars("Accuracy", {
            "Train" : train_accuracy,
            "Validation" : validation_accuracy
            },
            epoch + 1)
        tensorboard_writer.add_scalars("Loss", {
            "Train" : train_avg_loss,
            "Validation" : validation_avg_loss
            },
            epoch + 1)

        tensorboard_writer.flush()

        # Stop only after logging, so the final epoch also reaches TensorBoard.
        if early_stopping.early_stop:
            break
finally:
    # Release the event-file writer even if training is interrupted.
    tensorboard_writer.close()

# Persist the per-epoch history; one row per completed epoch.
train_losses, train_accuracys = np.array(train_losses), np.array(train_accuracys)
val_losses, val_accuracys = np.array(val_losses), np.array(val_accuracys)
train_logs = np.stack([train_losses, train_accuracys, val_losses, val_accuracys], axis=1)

train_log_df = pd.DataFrame(train_logs, columns=["train_loss", "train_accuracy", "validation_loss", "validation_accuracy"])
train_log_df.to_csv(EXPERIMENT_DIR + "/train_log.csv", sep=",")

In [None]:
# The best checkpoint saved during training lives in the experiment's models folder.
checkpoint_path = f"{EXPERIMENT_DIR}/models"
load_path = checkpoint_path + "/checkpoint.pth.tar"

def load_checkpoint(checkpoint, model, optimizer):
    """Restore model and optimizer state from a checkpoint mapping.

    Args:
        checkpoint: Mapping with a ``'state_dict'`` entry for the model and an
            ``'optimizer'`` entry for the optimizer.
        model: Object whose ``load_state_dict`` receives ``checkpoint['state_dict']``.
        optimizer: Object whose ``load_state_dict`` receives ``checkpoint['optimizer']``.
    """
    print("=> Loading checkpoint")
    # The model is restored first, then the optimizer.
    for target, key in ((model, 'state_dict'), (optimizer, 'optimizer')):
        target.load_state_dict(checkpoint[key])

# Restore the best checkpoint. `map_location` remaps the saved tensors onto the
# current device, so a checkpoint written on a GPU also loads on a CPU-only machine.
load_checkpoint(torch.load(load_path, map_location=device), model, optimizer)

In [None]:
# Evaluate the restored model on the test split, collect per-sample predictions
# for later analysis, and write the summary metrics to test_log.txt.
test_avg_loss, test_accuracy = 0.0, 0.0
y_pred, y_true = [], []
total_test_batch = len(test_dataloader)

# Set inference mode explicitly rather than relying on the state the training
# cell happened to leave the model in.
model.eval()

with torch.no_grad():
    for datas, labels in tqdm(test_dataloader, total=total_test_batch, desc="Test Progress"):
        datas, labels = datas.float().to(device, non_blocking=True), labels.to(device, non_blocking=True)

        test_predicts = model(datas)

        loss = criterion(test_predicts, labels)
        # The criterion returns the batch mean; weight by batch size for a
        # per-sample average after dividing by the dataset size below.
        test_avg_loss += loss.item() * datas.size(0)

        # The criterion is BCEWithLogitsLoss, so the outputs are treated as raw
        # logits; convert to probabilities before applying the 0.5 threshold.
        # NOTE(review): assumes Densenet201 has no final sigmoid — confirm.
        predicted_classes = torch.sigmoid(test_predicts) > 0.5
        test_accuracy += (predicted_classes == labels).float().sum().item()

        labels = labels.cpu().numpy()
        test_predicts = predicted_classes.float().cpu().numpy()
        y_pred.extend(test_predicts)
        y_true.extend(labels)

test_loss, test_accuracy = test_avg_loss / len(test_dataset.data_paths), test_accuracy / len(test_dataset.data_paths)
print(test_loss, test_accuracy)

# The context manager closes the file; no explicit close() is needed.
with open(EXPERIMENT_DIR + "/test_log.txt", "w", encoding='utf-8') as file:
    file.write(f"Test Loss : {test_loss}\n")
    file.write(f"Test Accuracy : {test_accuracy}")

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sn

# Plot the test-set confusion matrix and save it inside the experiment directory.
cm = confusion_matrix(y_true, y_pred)
labels = ["Non-Edited", "Edited"]
df_cm = pd.DataFrame(cm, index=list(labels), columns=list(labels))

fig, ax = plt.subplots(figsize=(12, 7))

# Heatmap columns are the predicted classes, hence the title.
ax.set_title('Predicted Class')

sn.set(font_scale=2)
sn.heatmap(df_cm, annot=True, fmt='d', annot_kws={"weight": "bold"}, ax=ax)

ax.set_ylabel('Actual Class')
# plt.figtext(0.5, 0.01, f'Accuracy: {test_accuracy}', wrap=True, ha='center', fontsize=24)

# os.path.join replaces the '\confusion_matrix.png' literal, which relied on an
# invalid escape sequence and a Windows-only path separator.
fig.savefig(os.path.join(EXPERIMENT_DIR, 'confusion_matrix.png'))

In [None]:
import shutil

# Copy every misclassified test sample into a per-error-type folder inside the
# experiment directory so the mistakes can be inspected by hand.

# Positions in the test set where the prediction differs from the ground truth.
error_indexes = [idx for idx, (p, t) in enumerate(zip(y_pred, y_true)) if p != t]

# Each entry is unpacked below as (path, label).
# NOTE(review): assumes test_dataset.data_paths follows the DataLoader order
# (the loader is built with shuffle=False) — confirm.
error_list = [test_dataset.data_paths[idx] for idx in error_indexes]

# os.path.join replaces the "\error_..." literals, which relied on invalid
# escape sequences and a Windows-only path separator.
legal_error_destination = os.path.join(EXPERIMENT_DIR, "error_ag_to_non")
os.makedirs(legal_error_destination, exist_ok=True)
illegal_error_destination = os.path.join(EXPERIMENT_DIR, "error_non_to_ag")
os.makedirs(illegal_error_destination, exist_ok=True)

for error_path, label in tqdm(error_list, total=len(error_list)):
    # Take the file name whether the stored path uses "\" or "/" separators.
    file_name = os.path.basename(error_path.replace("\\", "/"))
    # Samples whose true label is 1 go to the first folder, all others to the second.
    if label[0] == 1:
        destination = legal_error_destination
    else:
        destination = illegal_error_destination
    shutil.copy2(error_path, os.path.join(destination, file_name))