In [None]:
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
# torch.autograd.set_detect_anomaly(True)
import numpy as np
import yaml
from functools import partial
from time import gmtime, strftime, time
from sklearn.metrics import classification_report
# Hyper-parameters (batch_size, learning_rate, num_epochs, ...) are read from this YAML file.
# `with` guarantees the file handle is closed once parsed.
with open('resAE_parameters.yaml') as params_file:
    params = yaml.safe_load(params_file)

# Reproducibility: seed torch and numpy and make cuDNN deterministic.
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True  # Test this with False
torch.backends.cudnn.benchmark = False
np.random.seed(0)
device = torch.device("cuda")
batch_size = params['batch_size']

In [None]:
# Basic block

class Basicblock(nn.Module):
    expansion = 1
    def __init__(self, input_planes, planes, stride=1, dim_change=None):
        super(Basicblock,self).__init__()
        # Declare convolutional layers with batch norms
        self.conv1 = nn.Conv2d(input_planes, planes, stride=stride, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, stride=1, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(True)
        self.dim_change = dim_change
        
    def forward(self, x):
        # Save the residual
        res = x
        output = self.conv1(x)
        output = self.bn1(output)
        output = self.relu(output)
        output = self.conv2(output)
        output = self.bn2(output)
        
        if self.dim_change is not None:
            # print("res before : ", res.size())
            # print(self.dim_change)
            res = self.dim_change(res)
            # print("res after : ", res.size())
            # print("output size : ", output.size())
        output += res
        output = self.relu(output)
        
        return output
    

In [None]:
# Res Encoder

class ResEncoder(nn.Module):
    """Residual convolutional encoder.

    conv3x3 (3 -> 8 channels) + ReLU, then three residual stages with 16, 32
    and 64 channels (each downsampling by 2), then a stride-2 conv3x3 to 128
    channels + BatchNorm + LeakyReLU. With Basicblock the total downsampling is
    16x (a 240x240 input gives a 15x15 latent).

    Args:
        block: residual block class exposing ``expansion`` and accepting
            ``(input_planes, planes, stride=..., dim_change=...)``.
        num_layers: number of blocks in each of the three stages.
        classes: unused; kept for interface compatibility.
    """

    def __init__(self, block, num_layers, classes=2):
        super(ResEncoder, self).__init__()
        # Channel count entering the next stage; advanced by _layer().
        self.input_planes = 8
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1, stride=1)
        self.relu = nn.ReLU(True)

        self.layer1 = self._layer(block, 16, num_layers[0], stride=2)
        self.layer2 = self._layer(block, 32, num_layers[1], stride=2)
        self.layer3 = self._layer(block, 64, num_layers[2], stride=2)
        self.last_conv1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.last_bn1 = nn.BatchNorm2d(128)
        self.last_leaky_relu = nn.LeakyReLU(0.0000001)

    def _layer(self, block, planes, num_layers, stride=1):
        """Build one stage of ``num_layers`` blocks and advance ``self.input_planes``.

        Only the first block can change resolution/channels. It gets a projection
        shortcut (1x1 conv + BN) exactly when its output shape differs from its
        input shape. The stage is always returned, even when no projection is needed.
        """
        out_planes = planes * block.expansion
        dim_change = None
        # Standard ResNet rule: project whenever stride or channel count changes.
        if stride != 1 or self.input_planes != out_planes:
            dim_change = nn.Sequential(
                nn.Conv2d(self.input_planes, out_planes, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_planes))
        net_layers = [block(self.input_planes, planes, stride=stride, dim_change=dim_change)]
        self.input_planes = out_planes
        for _ in range(1, num_layers):
            net_layers.append(block(self.input_planes, planes))
        return nn.Sequential(*net_layers)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.last_conv1(x)
        x = self.last_bn1(x)
        return self.last_leaky_relu(x)

# Module-level encoder built from Basicblock residual blocks, 3 blocks per stage.
encoder = ResEncoder(Basicblock, [3, 3, 3])

In [None]:
# Res Decoder

class ResDecoder(nn.Module):
    """Residual convolutional decoder mirroring ResEncoder.

    Four stages of (2x nearest upsample -> residual blocks) take the channels
    128 -> 64 -> 32 -> 16 -> 8. A conv3x3 to 3 channels and tanh follow, so the
    output lies in [-1, 1].

    Args:
        block: residual block class exposing ``expansion`` and accepting
            ``(input_planes, planes, stride=..., dim_change=...)``.
        num_layers: blocks per stage. With three entries the fourth stage
            reuses ``num_layers[2]``; an optional fourth entry sets it explicitly.
        classes: unused; kept for interface compatibility.
    """

    def __init__(self, block, num_layers, classes=2):
        super(ResDecoder, self).__init__()
        # Channel count entering the next stage (the encoder latent has 128).
        self.input_planes = 128
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.layer1 = self._layer(block, 64, num_layers[0], stride=1)
        self.layer2 = self._layer(block, 32, num_layers[1], stride=1)
        self.layer3 = self._layer(block, 16, num_layers[2], stride=1)
        layer4_blocks = num_layers[3] if len(num_layers) > 3 else num_layers[2]
        self.layer4 = self._layer(block, 8, layer4_blocks, stride=1)
        self.last_conv1 = nn.Conv2d(in_channels=8, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.tanh_10_2 = nn.Tanh()

    def _layer(self, block, planes, num_layers, stride=2):
        """Build one stage of ``num_layers`` blocks and advance ``self.input_planes``.

        The first block gets a projection shortcut (1x1 conv + BN) exactly when
        its output shape differs from its input shape. The stage is always returned.
        """
        out_planes = planes * block.expansion
        dim_change = None
        # Standard ResNet rule: project whenever stride or channel count changes.
        if stride != 1 or self.input_planes != out_planes:
            dim_change = nn.Sequential(
                nn.Conv2d(self.input_planes, out_planes, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_planes))
        net_layers = [block(self.input_planes, planes, stride=stride, dim_change=dim_change)]
        self.input_planes = out_planes
        for _ in range(1, num_layers):
            net_layers.append(block(self.input_planes, planes))
        return nn.Sequential(*net_layers)

    def forward(self, x):
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(self.upsample(x))
        x = self.last_conv1(x)
        return self.tanh_10_2(x)

# Build a decoder from Basicblock residual blocks; [3, 3, 3] blocks for stages 1-3
# (the fourth stage also uses num_layers[2]).
decoder = ResDecoder(Basicblock, [3, 3, 3])

In [None]:
# ResAutoencoder
class ResAutoencoder(nn.Module):
    """Residual autoencoder with a label-driven latent selection block.

    The encoder latent's channels are split in two halves. ``[:C//2]`` is the
    "fake" subspace and ``[C//2:]`` the "real" subspace (label 0 -> fake,
    1 -> real). Before decoding, the half that does not belong to the sample's
    label is zeroed, so each image must be reconstructed from its own class
    subspace.

    Args:
        encoder: optional encoder module; defaults to a new
            ``ResEncoder(Basicblock, [3, 3, 3])``.
        decoder: optional decoder module; defaults to a new
            ``ResDecoder(Basicblock, [3, 3, 3])``.
    """

    def __init__(self, encoder=None, decoder=None):
        super(ResAutoencoder, self).__init__()
        # Each instance owns its sub-networks. Shared module-level instances would
        # make every ResAutoencoder alias the same weights, so loading a checkpoint
        # into one model would silently change the others.
        self.encoder = encoder if encoder is not None else ResEncoder(Basicblock, [3, 3, 3])
        self.decoder = decoder if decoder is not None else ResDecoder(Basicblock, [3, 3, 3])

    def forward(self, x, label):
        """Encode, zero the latent half not matching ``label``, and decode.

        Args:
            x: input image batch.
            label: length-N tensor. A non-zero entry marks a real sample and zero
                a fake one (the same truthiness test as ``label[i].item()``).
                It may live on any device.

        Returns:
            ``(reconstruction, latent)``: the decoder output for the *selected*
            latent, and the unmasked encoder latent (used by the activation loss
            and for classification).
        """
        latent = self.encoder(x)
        channels = latent.size(1)
        half = channels // 2

        is_real = (label.reshape(-1) != 0).to(latent.device).view(-1, 1, 1, 1)
        in_fake_half = (torch.arange(channels, device=latent.device) < half).view(1, -1, 1, 1)
        # Real samples drop the fake half; fake samples drop the real half.
        drop = torch.where(is_real, in_fake_half, ~in_fake_half)
        selected = latent.masked_fill(drop, 0.0)

        return self.decoder(selected), latent
    
# NOTE(review): this instance is not referenced anywhere else in the notebook;
# the network that is trained and evaluated is created separately as `model`.
resautoencoder = ResAutoencoder()


In [None]:
model = ResAutoencoder()
# Optional checkpoint to resume from; leave empty to start from fresh weights.
model_path = ''
if model_path:
    # load_state_dict() returns the missing/unexpected-key report, not the module,
    # so its result must not be assigned back to `model`.
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
model

In [None]:
# Training configuration; values come from the YAML parameters file.
learning_rate = params['learning_rate']
num_epochs = params['num_epochs']

# Move the network to the GPU, then build Adam over its parameters.
model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, eps=1e-7)

# Reconstruction criterion: mean absolute error.
criterion1 = nn.L1Loss()

In [None]:
# Dataset generation
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

# Dataset locations. Raw strings: the back-slashes are path separators, not escapes.
TRAIN_DIR = r"D:\labwork\local_deepfake\dataset/train"
VAL_DIR = r"D:\labwork\local_deepfake\dataset/val"
# 240x240 inputs give a 15x15 latent (the encoder downsamples by 16).
IMAGE_SIZE = (240, 240)

# Training images get light colour / flip / rotation augmentation.
train_dataset = torchvision.datasets.ImageFolder(
    root=TRAIN_DIR,
    transform=transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ColorJitter(hue=.05, saturation=.05),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(20),
        transforms.ToTensor()]))
# ImageFolder numbers classes in sorted folder-name order; the model assumes 0 -> fake, 1 -> real.
print("Class labels : ", train_dataset.class_to_idx)

validation_dataset = torchvision.datasets.ImageFolder(
    root=VAL_DIR,
    transform=transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor()]))

train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=4)
# Order does not affect validation metrics, so iterate deterministically.
validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
                                                    batch_size=batch_size,
                                                    shuffle=False,
                                                    num_workers=4)

print("train_dataset %d" % len(train_dataset))
print("validation set %d" % len( validation_dataset))

In [None]:
# Activation Loss
def act_loss_func(outputs, labels):
    """Activation loss that pushes each class into its own half of the latent.

    Channels ``[:C//2]`` form the "fake" subspace and ``[C//2:]`` the "real"
    subspace (label 0 -> fake, non-zero -> real). For each sample:

    * every channel of the *other* class's half adds its mean absolute
      activation ``sum|h_c| / (H*W)`` (driven towards 0);
    * every channel of the sample's *own* half adds
      ``|1 - sum|h_c| / (H*W*C/2)|``.

    For a (N, 128, 15, 15) latent the divisors are 225 and 14400.

    Args:
        outputs: latent tensor of shape (N, C, H, W).
        labels: length-N tensor of class labels, on any device.

    Returns:
        0-dim tensor on ``outputs.device``: the loss summed over the batch.
    """
    _, channels, height, width = outputs.shape
    half = channels // 2
    spatial = height * width

    # L1 mass of every channel, shape (N, C).
    l1 = outputs.abs().sum(dim=(2, 3))
    silenced = l1 / spatial                       # penalty for the other class's half
    active = (1 - l1 / (spatial * half)).abs()    # penalty for the own class's half

    fake_silenced = silenced[:, :half].sum(dim=1)
    fake_active = active[:, :half].sum(dim=1)
    real_silenced = silenced[:, half:].sum(dim=1)
    real_active = active[:, half:].sum(dim=1)

    is_real = labels.reshape(-1).to(outputs.device) != 0
    per_sample = torch.where(is_real,
                             fake_silenced + real_active,
                             fake_active + real_silenced)
    return per_sample.sum()


In [1]:
# test
def act_loss_test(outputs):
    """Turn latent activations into per-sample (fake, real) scores.

    Each score is the share of the sample's total L1 activation that lies in the
    corresponding channel half (``[:C//2]`` fake, ``[C//2:]`` real), so the two
    scores sum to 1. An all-zero latent yields (0.5, 0.5) instead of a
    ZeroDivisionError.

    Args:
        outputs: latent tensor of shape (N, C, H, W).

    Returns:
        float32 CPU tensor of shape (N, 2): column 0 = fake share, column 1 =
        real share. It stays on the CPU because the evaluation loops compare its
        argmax with CPU labels.
    """
    mass = outputs.detach().abs().sum(dim=(2, 3)).double()
    half = mass.size(1) // 2
    fake = mass[:, :half].sum(dim=1)
    real = mass[:, half:].sum(dim=1)
    total = fake + real

    nonzero = total > 0
    safe_total = torch.where(nonzero, total, torch.ones_like(total))
    shares = torch.stack([fake, real], dim=1) / safe_total.unsqueeze(1)
    shares = torch.where(nonzero.unsqueeze(1), shares, torch.full_like(shares, 0.5))
    return shares.float().cpu()


# Base Dataset Training

In [None]:
from torch.autograd import Variable
import torchvision.utils as vutils
from sklearn.metrics import classification_report
import os

# Class names in label-index order (ImageFolder sorts folder names), for the report.
target_names = validation_dataset.classes
print("Start training")
checkpoint_dir = os.path.join(
    r'D:\labwork\local_deep_transfer\source\saved_models/checkpoints_kaggle',
    strftime("%Y-%m-%d-%H-%M-%S", gmtime()))
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):
    # ---- training ----
    model.train()
    print("epoch : ", epoch)
    run_loss = 0.0
    start = time()
    for x, label in train_dataloader:
        x = x.to(device)
        label = label.to(device)
        output, act_data = model(x, label)
        rec_loss = criterion1(output, x)
        act_loss = act_loss_func(act_data, label)
        loss = act_loss + 0.1 * rec_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        run_loss += loss.item()
    end = time()
    # Report the epoch mean rather than only the last batch.
    print('epoch [{}/{}], loss:{:.4f}'
          .format(epoch + 1, num_epochs, run_loss / max(len(train_dataloader), 1)))
    print("Took ", end - start)

    # ---- validation ----
    model.eval()
    pred = []
    labels = []
    correct = 0
    val_loss = 0.0
    with torch.no_grad():
        for x, label in validation_dataloader:
            x = x.to(device)
            # Predictions come from `act_data`, the latent returned before the
            # label-driven selection, so true labels cannot leak into `predicted`.
            output, act_data = model(x, label.to(device))
            outputs = act_loss_test(act_data)

            rec_loss = criterion1(output, x)
            act_loss = act_loss_func(act_data, label)
            val_loss += (act_loss + 0.1 * rec_loss).item()

            predicted = outputs.argmax(dim=1)
            pred += predicted.tolist()
            labels += label.tolist()
            correct += (predicted == label).sum().item()

    print("Validation Loss is %f" % val_loss)
    accuracy = 100.0 * correct / max(len(validation_dataset), 1)
    print('Validation Accuracy %f %%' % accuracy)
    print(classification_report(labels, pred, target_names=target_names, digits=4))

    epoch_checkpoint_path = os.path.join(checkpoint_dir, "TAR_" + str(epoch) + 'epoch_.pth')
    torch.save(model.state_dict(), epoch_checkpoint_path)


# Test 

In [None]:
# Load a trained checkpoint for evaluation.
# Raw string: the back-slashes are path separators, not escape sequences.
CHECKPOINT_PATH = r"D:\labwork\local_deep_transfer\source\saved_models\checkpoints/resAE_face2face/resAE_face2face_small_leaky_11epoch_.pth"
model = ResAutoencoder()
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))
model.cuda()
# Last expression: displays the model architecture.
model.eval()

In [None]:
# ImageFolder-style test set (one sub-folder per class); must be set before running.
test_path = ''
assert test_path, "set test_path to the test-set directory"
test = torchvision.datasets.ImageFolder(
    root=test_path,
    transform=transforms.Compose([transforms.Resize((240, 240)), transforms.ToTensor()]))
# Order does not affect the metrics, so iterate deterministically.
test_dataloader = DataLoader(test, batch_size=batch_size, shuffle=False)


In [None]:
# Class names in label-index order (ImageFolder sorts folder names).
target_names = test.classes
pred = []
labels = []
correct = 0
with torch.no_grad():
    test_loss = 0.0
    for x, label in test_dataloader:
        x = x.to(device)
        # Predictions come from `act_data`, the latent returned before the
        # label-driven selection, so true labels cannot leak into `predicted`.
        output, act_data = model(x, label.to(device))
        outputs = act_loss_test(act_data)

        rec_loss = criterion1(output, x)
        act_loss = act_loss_func(act_data, label)
        test_loss += (act_loss + 0.1 * rec_loss).item()

        predicted = outputs.argmax(dim=1)
        pred += predicted.tolist()
        labels += label.tolist()
        correct += (predicted == label).sum().item()

    print("test Loss is %f" % test_loss)
    accuracy = 100.0 * correct / max(len(pred), 1)
    print('test Accuracy %f %%' % accuracy)
    print(classification_report(labels, pred, target_names=target_names, digits=4))