In [8]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from numpy.random import RandomState
import os
import cv2
import numpy as np
import pandas as pd
from models.used_models import MesoNet
from utils.data_loader import ImgDataset
from torchvision.transforms import Normalize
import torchvision.transforms as transforms

In [9]:
def main(seed=42,
         metadata_path=r'C:\Users\MSI\Downloads\faceforensicsimages\metadata.csv',
         epochs=5,
         batch_size=8):
    """Split the metadata 80/20, build data loaders and train a MesoNet.

    Args:
        seed: seed for the train/validation split and for torch (which also
            drives DataLoader shuffling), so runs are reproducible. Pass
            ``None`` for an unseeded split.
        metadata_path: CSV with one row per video; it must provide the
            ``filename`` and ``category`` columns read by ``ImgDataset``.
        epochs: number of training epochs forwarded to ``train``.
        batch_size: number of dataset items (videos) per batch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = True
    if seed is not None:
        torch.manual_seed(seed)

    metadata = pd.read_csv(metadata_path)
    rng = RandomState(seed)
    train_split = metadata.sample(frac=0.8, random_state=rng)
    val_split = metadata.drop(train_split.index)
    # Give both splits a contiguous 0..n-1 index so positional item access in
    # the dataset works for the sampled (shuffled-index) frames.
    train_df = train_split.reset_index(drop=True)
    test_df = val_split.reset_index(drop=True)

    # Train on the training split only, so validation videos never leak
    # into training.
    train_data = ImgDataset(train_df)
    val_data = ImgDataset(test_df)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size)

    model = MesoNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08)
    criterion = nn.CrossEntropyLoss()
    train(epochs, model, optimizer, criterion, train_loader, val_loader, device,
          len(train_df), len(test_df))

In [10]:
def train(epochs, model, optimizer, criterion, train_loader, val_loader, device, train_dataset_size, val_dataset_size):
    """Train ``model`` for ``epochs`` epochs with per-epoch checkpoints.

    After every epoch the model is scored with ``evaluate`` on ``val_loader``.
    Each epoch's weights are saved to ``state_dict/epoch<N>.pkl``. A snapshot
    of the weights with the highest validation accuracy is saved to
    ``./best.pkl`` at the end.

    Each loader item is ``(images, labels, ...)``. The first two dimensions
    of ``images``/``labels`` (items per batch, frames per item) are flattened
    into a single batch dimension before the forward pass.

    Args:
        epochs: number of passes over ``train_loader``.
        model: network to train (already on ``device``).
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(outputs, labels)``.
        train_loader: training data loader.
        val_loader: validation data loader, forwarded to ``evaluate``.
        device: device every batch is moved to.
        train_dataset_size: number of training items. Kept for interface
            compatibility; the reported loss is averaged over the frames
            actually seen.
        val_dataset_size: forwarded to ``evaluate``.
    """
    import copy
    import os

    # Deep copy: state_dict() returns references to the live parameter
    # tensors, which keep changing while training continues.
    best_model = copy.deepcopy(model.state_dict())
    best_acc = float('-inf')
    os.makedirs("state_dict", exist_ok=True)

    for epoch in range(epochs):
        model.train()
        loss_sum = 0.0
        samples_seen = 0

        for data in train_loader:
            images = data[0].to(device)
            labels = data[1].to(device)
            images = images.reshape((-1,) + images.shape[2:])
            labels = labels.reshape(-1)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Assumes a mean-reducing criterion (nn.CrossEntropyLoss default):
            # weight by batch size so the epoch value is a per-frame mean.
            loss_sum += loss.item() * labels.size(0)
            samples_seen += labels.size(0)

        epoch_loss = loss_sum / samples_seen if samples_seen else 0.0
        print('epoch %d train loss: %.6f' % (epoch, epoch_loss))

        # Passing -inf makes evaluate report this epoch's accuracy
        # unconditionally; best-model tracking is done here with a deep copy.
        _, val_acc = evaluate(model, val_loader, device, val_dataset_size, float('-inf'))
        val_acc = float(val_acc)
        if val_acc > best_acc:
            best_acc = val_acc
            best_model = copy.deepcopy(model.state_dict())

        torch.save(model.state_dict(), "state_dict/epoch" + str(epoch) + ".pkl")
    torch.save(best_model, './best.pkl')



In [11]:
def evaluate(model, val_loader, device, val_dataset_size, best_acc, best_model=None):
    """Measure frame-level accuracy of ``model`` on ``val_loader``.

    Each loader item is ``(images, labels, ...)``. The first two dimensions
    are flattened into one batch dimension, so accuracy is counted per frame.

    Args:
        model: network to evaluate (switched to eval mode).
        val_loader: validation data loader.
        device: device every batch is moved to.
        val_dataset_size: number of validation items. Kept for interface
            compatibility; accuracy is divided by the number of frames
            actually evaluated, not by the item count.
        best_acc: best accuracy seen so far.
        best_model: state dict paired with ``best_acc``; returned unchanged
            when this evaluation does not improve on it.

    Returns:
        ``(best_model, best_acc)``. If this run's accuracy (a float in
        [0, 1]; 0.0 for an empty loader) beats ``best_acc``, the result is a
        deep copy of the current state dict and the new accuracy. Otherwise
        the inputs are returned unchanged.
    """
    import copy

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in val_loader:
            images = data[0].to(device)
            labels = data[1].to(device)
            images = images.reshape((-1,) + images.shape[2:])
            labels = labels.reshape(-1)

            _, preds = torch.max(model(images), 1)
            correct += (preds == labels).sum().item()
            total += labels.numel()

    # Decide only once the whole validation set has been scored.
    epoch_acc = correct / total if total else 0.0
    print("epoch_acc - ", epoch_acc)
    if epoch_acc > best_acc:
        best_acc = epoch_acc
        # Copy, because state_dict() aliases the live parameters.
        best_model = copy.deepcopy(model.state_dict())
    return best_model, best_acc
    

In [12]:
# Per-channel mean / standard deviation applied to every frame. These are the
# widely used ImageNet statistics, listed in RGB channel order.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
normalize_transform = Normalize(mean=mean, std=std)

def load_images(filename, label, n_frames=100,
                images_path=r'C:\Users\MSI\Downloads\faceforensicsimages'):
    """Load the first ``n_frames`` readable frames of one video.

    Frames are looked up as
    ``<images_path>/<label>/<stem>_frames/<stem>_img__<j>.jpg`` in ascending
    ``j``. Missing indices and files OpenCV cannot decode are skipped.

    Args:
        filename: video file name; its extension is stripped to form ``stem``.
        label: category folder name. ``'original'`` maps to class 0 and any
            other value maps to class 1.
        n_frames: number of frames to return (default 100).
        images_path: root directory containing the extracted frames.

    Returns:
        ``(X, y)``:
        - ``X``: float tensor of shape ``(n_frames, 3, 256, 256)`` holding RGB
          frames scaled to [0, 1] and passed through ``normalize_transform``.
        - ``y``: tensor holding ``n_frames`` copies of the class label.

    Raises:
        FileNotFoundError: if the frames directory does not exist.
        ValueError: if fewer than ``n_frames`` readable frames are available.
    """
    import re

    stem = os.path.splitext(filename)[0]
    frames_path = os.path.join(images_path, label, stem + '_frames')
    # Match exactly the names produced by str(j): no leading zeros.
    frame_re = re.compile(re.escape(stem + '_img__') + r'(0|[1-9]\d*)\.jpg$')

    candidates = []
    for name in os.listdir(frames_path):
        match = frame_re.match(name)
        if match:
            candidates.append((int(match.group(1)), name))
    candidates.sort()

    X = torch.zeros((n_frames, 3, 256, 256))
    loaded = 0
    for _, name in candidates:
        if loaded == n_frames:
            break
        img = cv2.imread(os.path.join(frames_path, name))
        if img is None:
            continue
        # cv2 decodes to BGR (H, W, C) uint8; the ImageNet statistics used by
        # normalize_transform are in RGB order, so convert before normalizing.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        frame = torch.from_numpy(img).float().permute(2, 0, 1) / 255.0
        X[loaded] = normalize_transform(frame)
        loaded += 1

    if loaded < n_frames:
        raise ValueError('only %d of %d frames readable in %s'
                         % (loaded, n_frames, frames_path))

    y = 0 if label == 'original' else 1
    return X, torch.tensor([y] * n_frames)

In [13]:
from torch.utils.data import Dataset

class ImgDataset(Dataset):
    """Dataset whose item ``i`` is the ``load_images`` output for row ``i``.

    NOTE: this redefines the ``ImgDataset`` imported from
    ``utils.data_loader`` in the first cell. After this cell runs, the
    notebook-local version is the one in effect.

    Args:
        df: DataFrame with ``filename`` and ``category`` columns. Its index
            may be arbitrary (e.g. after ``DataFrame.sample``); rows are
            always addressed by position.
    """

    def __init__(self, df):
        # A positional 0..n-1 index makes item access independent of the
        # caller's index labels.
        self.df = df.reset_index(drop=True)
        self.filename = self.df['filename']
        self.labels = self.df['category']

    def __getitem__(self, index):
        return load_images(self.filename.iloc[index], self.labels.iloc[index])

    def __len__(self):
        return len(self.df)

In [None]:
if __name__ == "__main__":
    # Runs the split / train / evaluate pipeline when executed as a script,
    # or when this cell is run in the notebook kernel.
    main()