In [1]:
import torchvision

from fastai.vision.all import *

from pathlib import Path

from PIL import Image

from torchvision.transforms.functional import to_pil_image, to_tensor
from torchvision.transforms import Normalize

import numpy as np

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

import pandas as pd

In [2]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

## Carregando modelo do Fastai Pré Treinado

In [3]:
def fastai_to_pytorch(base_model, path, path_to_model):
    all_data = ImageDataLoaders.from_folder(path=path, # Indicamos aqui o caminho que contém as 2 pastas FAKE e REAL
                                        train='train', # Indicamos aqui o nome da pasta que é de treinamento. Observe que como nosso dataset não possui ainda a divisão, deixamos padrão
                                        valid='valid', # O mesmo vale para a parte de validação
                                        seed=42, # Fornecemos um seed para que a divisão seja sempre a mesma, não importa a execução
                                        bs=1, # Fornecemos o tamanho do lote de imagens
                                        item_tfms=Resize(224), # Indicamos que queremos redimensionar os itens individuais para o tamanho indicado (224x224 no caso)
                                        #batch_tfms=Normalize.from_stats(*imagenet_stats), # Indicamos que queremos normalizar as imagens com os mesmos status de quando foram treinadas 
                                        num_workers=0 # Esse argumento é necessário para funcionar no windows
                                       )
    learn = cnn_learner(dls=all_data, 
                        arch=base_model)

    learn.load("../../" + path_to_model)
    
    return learn.model.eval()

In [4]:
model_1 = fastai_to_pytorch(resnet18, path="./Dummy Dataset", path_to_model="models/resnet18_best_mcc_final")
model_2 = fastai_to_pytorch(resnet18, path="./Dummy Dataset", path_to_model="models/resnet18_best_mcc_F")

In [137]:
class DFDataset(torch.utils.data.Dataset):
    
    def __init__(self, real_groups, 
                 fake_groups, 
                 path="./Faces Dataset Transformer", 
                 path_to_model="models/resnet18_best_mcc_final",
                 device="cuda"):
        
        self.device = device
        self.path = path
        self.groups = real_groups + fake_groups
        self.labels = torch.cat((torch.zeros(len(real_groups)).long(),
                                 torch.ones(len(fake_groups)).long()
                                ), axis=0)
        self.model = self.fastai_to_pytorch(base_model=resnet18, path=path, path_to_model=path_to_model).eval().to(device)
        self.model[1][8] = nn.Identity()
        self.size = 224
        self.retrieve_tensors()
        
    def fastai_to_pytorch(self, base_model, path, path_to_model):
        all_data = ImageDataLoaders.from_folder(path=path,
                                        train='train',
                                        valid='valid',
                                        valid_pct=0.1,
                                        seed=42,
                                        bs=1,
                                        item_tfms=Resize(224),
                                        num_workers=0)
        
        learn = cnn_learner(dls=all_data, 
                            arch=base_model)

        learn.load("../../" + path_to_model)

        return learn.model
        
    def retrieve_tensors(self):
        self.data = torch.tensor([]).to(self.device)

        with torch.no_grad():
            for groups in self.groups:
                out = torch.tensor([]).to(self.device)
                img_zeros = torch.tensor([]).to(self.device)
                for i, image_path in enumerate(groups):
                    if image_path != "__PAD__":
                        out = torch.cat((
                            (Normalize(*imagenet_stats)(to_tensor(Image.open(image_path).resize((self.size, self.size))))).unsqueeze(0).to(self.device),
                            out), axis=0)
                    else:
                        img_zeros = torch.zeros(len(self.groups[i])-i, 512).to(self.device)
                        break
                
                out = torch.cat((self.model(out), img_zeros), axis=0)
                
                self.data = torch.cat((out.unsqueeze(0), self.data), axis=0)
                
        self.data = self.data.detach().cpu()
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, i):
        return (self.data[i], self.labels[i])

In [138]:
class DFDataModule(pl.LightningDataModule):
    
    def __init__(self, batch_size=32, path="./Faces Dataset Transformer", path_to_model="models/resnet18_best_mcc_final"):
        super().__init__()
        self.path = path
        self.batch_size = batch_size

    def setup(self, stage):
        
        base_dir = Path(self.path)
        all_real_paths = list((base_dir/"REAL").glob("*jpg"))
        all_fake_paths = list((base_dir/"FAKE").glob("*jpg"))
        
        groups = [[], []]
        max_len = 0
        init_name = None
        for j, groups_path in enumerate([all_real_paths, all_fake_paths]):
            for i, path in enumerate(groups_path):
                path_stem = path.stem
                folder, name, pos = path_stem.split(" ")
                if name != init_name:
                    if i != 0:
                        if len(group) > max_len:
                            max_len = len(group)
                        indexes = np.array([group_path.stem.split(" ")[2].replace(".jpg","") for group_path in group], dtype=np.int8)
                        sort = indexes.argsort()
                        groups[j].append([group[s] for s in sort])
                    init_name = name
                    group = []
                else:
                    group.append(path)

        for group in groups:
            for i, gr in enumerate(group):
                n_append = max_len - len(gr)
                if n_append != 0:
                    group[i] = group[i] + ["__PAD__" for n in range(n_append)]
        
        self.real_groups, self.fake_groups = groups[0], groups[1]
        
        self.dfdataset = DFDataset(self.real_groups, self.fake_groups)
        ratio = 0.7
        self.dfdataset_train, self.dfdataset_val = torch.utils.data.random_split(self.dfdataset, [int(ratio*len(self.dfdataset)), 
                                                       len(self.dfdataset) - int(ratio*len(self.dfdataset))])
        
    def train_dataloader(self):
        train_dl = torch.utils.data.DataLoader(self.dfdataset_train, batch_size=self.batch_size, shuffle=True)
        return train_dl
        
    def val_dataloader(self):
        val_dl = torch.utils.data.DataLoader(self.dfdataset_val, batch_size=self.batch_size, shuffle=True)
        return val_dl

In [139]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=100):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class CNNFeatureTransformerWithoutModel(nn.Module):
    
    def __init__(self, d_model, max_len):
        super(CNNFeatureTransformerWithoutModel, self).__init__()
        self.pos_encoder = PositionalEncoding(d_model=d_model, max_len=max_len)
        self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=d_model, nhead=8, dim_feedforward=2048), num_layers=6)
        self.decoder = nn.Linear(d_model, 2)
        
    def forward(self, x):
        x = self.pos_encoder(x)
        x = self.transformer(x)
        x = self.decoder(x)
        return x[:, 0, :] # Classificação
    
class CNNFeatureTransformerWithModel(nn.Module):
    
    def __init__(self, d_model, model, max_len):
        super(CNNFeatureTransformerWithModel, self).__init__()
        self.model = model
        self.pos_encoder = PositionalEncoding(d_model=d_model, max_len=max_len)
        self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=d_model, nhead=8, dim_feedforward=2048), num_layers=6)
        self.decoder = nn.Linear(d_model, 2)
        
    def forward(self, x):
        batch_size = x.shape[0]
        num_samples = x.shape[1]
        
        x = x.reshape(x.shape[0]*x.shape[1], x.shape[2], x.shape[3], x.shape[4])
        x = self.model(x).reshape(batch_size, num_samples, -1)
         
        x = self.pos_encoder(x)
        x = self.transformer(x)
        x = self.decoder(x)
        return x[:, 0, :] # Classificação
    
cnnft_modelWM = CNNFeatureTransformerWithModel(d_model=512, model=model, max_len=2000)
cnnft_model = CNNFeatureTransformerWithoutModel(d_model=512, max_len=2000)

In [178]:
class ConvFusionTransformer(pl.LightningModule):
    
    def __init__(self, d_model, max_len):
        super().__init__()
        self.pos_encoder = PositionalEncoding(d_model=d_model, max_len=max_len)
        self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=d_model, nhead=8, dim_feedforward=2048), num_layers=6)
        self.decoder = nn.Linear(d_model, 2)
        self.accuracy = pl.metrics.Accuracy()
        
    def forward(self, x): # No lightning são as ações de inferência
        x = self.pos_encoder(x)
        x = self.transformer(x)
        x = self.decoder(x)
        return x[:, 0, :]
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer
    
    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        x = self.pos_encoder(x)
        x = self.transformer(x)
        x = self.decoder(x)[:, 0, :]
        loss = F.cross_entropy(x, y)
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        x = self.pos_encoder(x)
        x = self.transformer(x)
        x = self.decoder(x)[:, 0, :]
        loss = F.cross_entropy(x, y)
        self.log('val_loss', loss)
        self.accuracy(F.softmax(x, dim=1), y)
        self.log('valid_acc', self.accuracy, on_step=True, on_epoch=True)
        return loss

## Treinamento

Vamos constatar a eficácia.

In [179]:
current_model = ConvFusionTransformer(d_model=512, max_len=200)
dfmodel = DFDataModule()
trainer = pl.Trainer(gpus='0', callbacks=[EarlyStopping(monitor="val_loss", patience=5)])

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


In [180]:
trainer.fit(current_model, datamodule=dfmodel)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type               | Params
---------------------------------------------------
0 | pos_encoder | PositionalEncoding | 0     
1 | transformer | TransformerEncoder | 18.9 M
2 | decoder     | Linear             | 1.0 K 
3 | accuracy    | Accuracy           | 0     
---------------------------------------------------
18.9 M    Trainable params
0         Non-trainable params
18.9 M    Total params
75.661    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

1

In [69]:
def get_accuracy(cnnft_model, wm):
    cnnft_model.to(device).eval()
    with torch.no_grad():
        labels = []
        predictions = []
        if wm:
            for features, targets in test_dlWM:
                features, targets = features.to(device), targets
                out = cnnft_model(features)

                predictions += out.max(dim=1)[1].tolist()

                labels += targets.tolist()
        else:
            for features, targets in test_dl:
                features, targets = features.to(device), targets
                out = cnnft_modelWM(features)

                predictions += out.max(dim=1)[1].tolist()

                labels += targets.tolist()
            
    right = np.sum(np.array(labels) == np.array(predictions))
    total = len(labels)
    acc_for_zero = []
    acc_for_one = []
    for i, j in zip(labels, predictions):
        if i == 1 and i == j:
            acc_for_one.append(True)
        elif i == 1 and i != j:
            acc_for_one.append(False)
        elif i == 0 and i == j:
            acc_for_zero.append(True)
        else:
            acc_for_zero.append(False)
    
    acc_for_zero = np.array(acc_for_zero)
    acc_for_one = np.array(acc_for_one)
    print(f"Accuracy: {right/total*100: .2f}%\t| Accuracy for FAKE: {acc_for_zero.sum()/len(acc_for_zero)*100:.2f}%\t| Accuracy for REAL: {acc_for_one.sum()/len(acc_for_one)*100:.2f}%")
    return np.array(labels), np.array(predictions)

In [70]:
labels2, preds2 = get_accuracy(cnnft_modelWM, wm=False)

Accuracy:  68.99%	| Accuracy for FAKE: 56.58%	| Accuracy for REAL: 80.49%
