In [None]:

import copy
import importlib
import os
from multiprocessing import Manager

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from numpy.random import RandomState
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision.transforms import Normalize

import data_preparation
from data_loader import ImgDataset
from models.MesoNet import MesoNet


In [None]:
def evaluate(model, val_loader, device, best_acc, criterion, frames, best_model=None):
    """Run one validation pass and keep a snapshot of the best weights.

    Args:
        model: network to evaluate; it is switched to eval mode.
        val_loader: DataLoader whose batches are ``(images, labels, ...)``.
            ``images`` is a sequence of per-frame batches, each assumed to be
            ``(batch, C, H, W)``. ``labels`` is assumed to be ``(batch, n_frames)``
            (one label per frame) -- TODO confirm against ImgDataset.
        device: device the inputs are moved to.
        best_acc: best validation accuracy seen so far (number or 0-d tensor).
        criterion: loss function; assumed to use mean reduction (the
            ``nn.CrossEntropyLoss`` default), since it is re-weighted by the
            batch size below.
        frames: unused; kept so existing callers keep working. Sample counts
            are taken from the labels themselves.
        best_model: state dict of the previous best weights. It is returned
            unchanged when this pass does not improve on ``best_acc``.

    Returns:
        ``(best_model, best_acc, epoch_loss)``. ``epoch_loss`` is the mean
        per-frame loss over the whole loader (``nan`` if the loader is empty).
        When accuracy improves, ``best_model`` is a deep copy of the current
        weights and ``best_acc`` is the new accuracy as a float in [0, 1].
    """
    model.eval()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for data in val_loader:
            images, labels = data[0], data[1]
            # Stack frames on dim 1 and flatten so images are ordered
            # sample-major (s0f0, s0f1, ..., s1f0, ...), matching the row-major
            # flattening of the (batch, n_frames) labels. A plain torch.cat
            # would be frame-major and mis-pair images with labels.
            images = torch.stack(list(images), dim=1).flatten(0, 1)
            # Default float dtype, as the original torch.Tensor(...) buffer had.
            images = images.to(device=device, dtype=torch.get_default_dtype())
            labels = labels.to(device).reshape(-1)

            outputs = model(images)
            loss = criterion(outputs, labels)
            batch_count = labels.numel()
            # criterion averages over the batch; re-weight so the epoch value
            # is a true per-sample mean even when the last batch is short.
            running_loss += loss.item() * batch_count
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += batch_count

    epoch_loss = running_loss / n_seen if n_seen else float('nan')
    epoch_acc = n_correct / n_seen if n_seen else None

    if epoch_acc is not None and epoch_acc > best_acc:
        # deepcopy: state_dict() tensors share storage with the live
        # parameters, so an uncopied "best" would drift with further training.
        return copy.deepcopy(model.state_dict()), epoch_acc, epoch_loss

    if best_model is None:
        # Backward compatibility: callers that do not pass best_model
        # previously always received the current weights.
        best_model = copy.deepcopy(model.state_dict())
    return best_model, best_acc, epoch_loss
    

In [None]:
def train(epochs, model, optimizer, train_loader, val_loader, device, scheduler, frames=10):
    """Train ``model`` with cross-entropy, validating and checkpointing each epoch.

    Each epoch's weights are saved to ``state_dict/epoch<N>.pkl``. The weights
    with the highest validation accuracy are saved to ``state_dict/best.pkl``
    at the end (the directory must already exist).

    Args:
        epochs: number of passes over ``train_loader``.
        model: network to train (already on ``device``).
        optimizer: optimizer over ``model``'s parameters.
        train_loader: DataLoader yielding ``(images, labels, ...)`` in the same
            layout that ``evaluate`` expects. ``images`` is a sequence of
            per-frame batches, and ``labels`` is assumed to be
            ``(batch, n_frames)`` -- TODO confirm against ImgDataset.
        val_loader: DataLoader used for validation after every epoch.
        device: device the inputs are moved to.
        scheduler: LR scheduler stepped once per epoch.
        frames: frames per video, forwarded to ``evaluate``.

    Returns:
        ``(train_losses, val_losses, train_accuracy, val_accuracy)``. These are
        per-epoch lists of Python floats, with accuracies in [0, 1].
    """
    criterion = nn.CrossEntropyLoss()
    # deepcopy: a bare state_dict() aliases the live parameters.
    best_model = copy.deepcopy(model.state_dict())
    # -inf so the first evaluated epoch always becomes the best checkpoint.
    best_acc = float('-inf')
    train_losses, val_losses, val_accuracy, train_accuracy = [], [], [], []

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        n_correct = 0
        n_seen = 0

        for batch, data in enumerate(train_loader):
            print('batch number: ', batch)
            images, labels = data[0], data[1]
            # Stack frames on dim 1 and flatten so images are sample-major,
            # matching the row-major flattening of the labels below.
            images = torch.stack(list(images), dim=1).flatten(0, 1)
            images = images.to(device=device, dtype=torch.get_default_dtype())
            labels = labels.to(device).reshape(-1)

            optimizer.zero_grad()
            predictions = model(images)
            loss = criterion(predictions, labels)
            loss.backward()
            optimizer.step()

            batch_count = labels.numel()
            # Mean-reduced loss re-weighted to give a per-sample epoch mean.
            running_loss += loss.item() * batch_count
            n_correct += (predictions.argmax(dim=1) == labels).sum().item()
            n_seen += batch_count

        # A -inf baseline makes evaluate() report THIS epoch's accuracy rather
        # than a best-so-far value. The best checkpoint is tracked here.
        _, val_epoch_acc, val_epoch_loss = evaluate(
            model, val_loader, device, float('-inf'), criterion, frames)
        # float() also unwraps 0-d (possibly CUDA) tensors so the lists plot.
        val_epoch_acc = float(val_epoch_acc)
        val_epoch_loss = float(val_epoch_loss)
        if val_epoch_acc > best_acc:
            best_acc = val_epoch_acc
            best_model = copy.deepcopy(model.state_dict())

        train_epoch_loss = running_loss / n_seen if n_seen else float('nan')
        train_epoch_acc = n_correct / n_seen if n_seen else 0.0
        train_losses.append(train_epoch_loss)
        val_losses.append(val_epoch_loss)
        val_accuracy.append(val_epoch_acc)
        train_accuracy.append(train_epoch_acc)
        print('epoch train loss: ', train_epoch_loss)
        print('epoch val loss: ', val_epoch_loss)
        print("epoch: %d" % (epoch))
        torch.save(model.state_dict(), os.path.join("state_dict", "epoch" + str(epoch) + ".pkl"))
        scheduler.step()

    torch.save(best_model, os.path.join('state_dict', 'best.pkl'))
    print(val_accuracy)
    return train_losses, val_losses, train_accuracy, val_accuracy

In [None]:
def plots_losses(train_losses, val_losses, train_accuracy, val_accuracy):
    """Plot training vs. validation loss and accuracy, then save and show them.

    The loss figure is written to ``loss.png`` and the accuracy figure to
    ``accuracy.png`` in the current working directory.

    Args:
        train_losses: per-epoch training loss values.
        val_losses: per-epoch validation loss values.
        train_accuracy: per-epoch training accuracy values.
        val_accuracy: per-epoch validation accuracy values.

    Each value may be a Python number or a one-element tensor on any device.
    """
    def _plot_pair(train_values, val_values, metric, filename):
        # One figure per metric, with training and validation curves over epochs.
        fig, ax = plt.subplots()
        # float() accepts numbers and one-element tensors, including CUDA
        # tensors, which matplotlib cannot convert on its own.
        ax.plot([float(v) for v in train_values], label='Training ' + metric)
        ax.plot([float(v) for v in val_values], label='Validation ' + metric)
        ax.set(title='Training vs. validation ' + metric, xlabel='epoch', ylabel=metric)
        ax.legend(frameon=False)
        fig.savefig(filename)
        plt.show()
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)

    _plot_pair(train_losses, val_losses, 'loss', 'loss.png')
    _plot_pair(train_accuracy, val_accuracy, 'accuracy', 'accuracy.png')
    

In [None]:
def main(path=r'C:\Users\artur\Downloads\faceforensics_frames\metadata.csv', frames=10,
         epochs=8, batch_size=16, seed=None):
    """Train MesoNet on an 80/20 split of a metadata CSV and plot the curves.

    Args:
        path: metadata CSV read with ``pd.read_csv``. The default is a
            machine-specific absolute path; override it on other machines.
        frames: frames per video, forwarded to both datasets and ``train``.
        epochs: number of training epochs.
        batch_size: DataLoader batch size.
        seed: seed for the train/validation split. ``None`` gives a fresh
            random split on every run, which was the original behaviour.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True

    metadata = pd.read_csv(path)
    rng = RandomState(seed)
    train_df = metadata.sample(frac=0.8, random_state=rng)
    val_df = metadata.loc[~metadata.index.isin(train_df.index)]

    train_data = ImgDataset(train_df, emotion='Normal', transform="Harshil_Albumentations", frames=frames)
    # Pass frames explicitly so validation uses the same frame count as
    # training instead of whatever ImgDataset defaults to.
    val_data = ImgDataset(val_df, frames=frames)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # Metrics do not depend on order, so the validation set is not shuffled.
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

    model = MesoNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
    # Halve the learning rate every 5 epochs.
    scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    # train() writes checkpoints to ./state_dict relative to the working directory.
    os.makedirs(os.path.join(os.path.abspath(''), "state_dict"), exist_ok=True)

    train_losses, val_losses, train_accuracy, val_accuracy = train(
        epochs, model, optimizer, train_loader, val_loader, device, scheduler, frames)
    print(train_accuracy, val_accuracy)
    print(train_losses, val_losses)
    plots_losses(train_losses, val_losses, train_accuracy, val_accuracy)


if __name__ == '__main__':
    main()

In [None]:
def test(path=r'C:\Users\artur\Downloads\faceforensics_frames\metadata.csv', frames=10,
         weights_path=os.path.join('state_dict', 'best.pkl'), batch_size=16):
    """Evaluate saved MesoNet weights on every video listed in a metadata CSV.

    Args:
        path: metadata CSV read with ``pd.read_csv``. The default is a
            machine-specific absolute path; override it on other machines.
        frames: frames per video, forwarded to ``ImgDataset``.
        weights_path: checkpoint produced by ``train`` (a state dict).
        batch_size: DataLoader batch size.

    Returns:
        Per-frame accuracy over the whole dataset, as a float in [0, 1]
        (0.0 if the dataset is empty). The value is also printed.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MesoNet()
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model = model.to(device)
    # Eval mode so layers like dropout/batch-norm (if MesoNet has them) behave
    # deterministically.
    model.eval()

    metadata = pd.read_csv(path)
    # Pass frames by keyword: main() does the same, and positional order of
    # ImgDataset's parameters is not guaranteed to put frames second.
    test_dataset = ImgDataset(metadata, frames=frames)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    print(len(metadata), "videos")

    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data[0], data[1]
            # Stack frames on dim 1 and flatten so images are sample-major,
            # matching the row-major flattening of the labels below
            # (labels assumed (batch, n_frames) -- TODO confirm).
            images = torch.stack(list(images), dim=1).flatten(0, 1)
            images = images.to(device=device, dtype=torch.get_default_dtype())
            labels = labels.to(device).reshape(-1)
            preds = model(images).argmax(dim=1)
            n_correct += (preds == labels).sum().item()
            n_seen += labels.numel()

    # Accuracy over all batches, not just the last one.
    accuracy = n_correct / n_seen if n_seen else 0.0
    print('Accuracy {:.4f}'.format(accuracy))
    return accuracy

test_accuracy = test()