In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from MLP_classifier import MultiClassClassifier
from torch.utils.data import DataLoader
import sys
sys.path.append(".")
from dataset import DeepFakeDatasetFastLoad, OOD
import torch.nn as nn
import torch
sys.path.append("../tools")
from constants import PATH_TO_DATA, SEED
from sklearn.model_selection import train_test_split
from torch.utils.data import random_split

  from .autonotebook import tqdm as notebook_tqdm


In [35]:
device = 0
model = MultiClassClassifier().cuda(device=device)

lr = 1e-3
batch_size = 64
epochs = 5

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

data = DeepFakeDatasetFastLoad("../../data/df_34000.pt")

rng = torch.Generator().manual_seed(SEED)
train_data, test_data = random_split(data,[0.8,0.2],generator=rng)

train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True)
test_loader  = DataLoader(test_data,batch_size=len(test_data),shuffle=True)

model.train()

MultiClassClassifier(
  (fc1): Linear(in_features=768, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=18, bias=True)
  (act): ReLU()
)

## Train for multi-class classification

In [36]:
model.set_generators_maps(gen_to_int=data.gen_to_int,int_to_gen=data.int_to_gen)

In [37]:
n_epochs = 1000
for epoch in range(1,n_epochs+1):
    for idx, batch in enumerate(train_loader):
        # prediction and loss
        pred = model(batch["features"].cuda(device))
        loss = loss_fn(pred,batch["generator"].cuda(device))

        # backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
    loss, current = loss.item(), idx*batch_size + len(batch["features"])
    if epoch%10 == 0 and epoch > 0:
        print(f"loss: {loss:>7f}  [{epoch:>5d}/{n_epochs:>5d}]")

loss: 1.272747  [   10/ 1000]
loss: 0.822165  [   20/ 1000]
loss: 0.587453  [   30/ 1000]
loss: 0.423171  [   40/ 1000]
loss: 0.597076  [   50/ 1000]
loss: 0.541960  [   60/ 1000]
loss: 0.346116  [   70/ 1000]
loss: 0.275015  [   80/ 1000]
loss: 0.234322  [   90/ 1000]
loss: 0.259733  [  100/ 1000]
loss: 0.411867  [  110/ 1000]
loss: 0.265251  [  120/ 1000]
loss: 0.239375  [  130/ 1000]
loss: 0.202963  [  140/ 1000]
loss: 0.292125  [  150/ 1000]
loss: 0.363535  [  160/ 1000]
loss: 0.182247  [  170/ 1000]
loss: 0.228046  [  180/ 1000]
loss: 0.313161  [  190/ 1000]
loss: 0.166717  [  200/ 1000]
loss: 0.242315  [  210/ 1000]
loss: 0.213310  [  220/ 1000]
loss: 0.363220  [  230/ 1000]
loss: 0.267730  [  240/ 1000]
loss: 0.188220  [  250/ 1000]
loss: 0.250089  [  260/ 1000]
loss: 0.203429  [  270/ 1000]
loss: 0.246269  [  280/ 1000]
loss: 0.124978  [  290/ 1000]
loss: 0.348315  [  300/ 1000]
loss: 0.178646  [  310/ 1000]
loss: 0.098493  [  320/ 1000]
loss: 0.267119  [  330/ 1000]
loss: 0.22

## Test for binary classification

In [38]:
import torch.types

with torch.no_grad():
    model.eval()
    for e in test_loader:
        accuracy = model.get_model_accuracy_binary(features=e["features"],
                                                   true_labels=e["label"],
                                                   device="cuda:" + str(device))
print(accuracy)

0.9729411602020264


### Comparison with SVM

In [8]:
from sklearn.svm import LinearSVC
from sklearn.multiclass import OneVsOneClassifier

clf = OneVsOneClassifier(LinearSVC(dual="auto"))

In [9]:
train_loader_all = DataLoader(train_data,batch_size=len(train_data))
for e in train_loader_all:
    X_train = e["features"]
    gen_train = e["generator"]
    label_train = e["label"]
for e in test_loader:
    X_test = e["features"]
    gen_test = e["generator"]
    label_test = e["label"]

clf.fit(X_train, gen_train) # train on multi-class classification

In [10]:
import numpy as np
pred = data.class_to_label(clf.predict(X_test))
np.mean(label_test.numpy() == pred.numpy()) # binary classification performance

0.97

## Test on multi-class classification

### Neural Network

In [None]:
import torch.types

with torch.no_grad():
    model.eval()
    for e in test_loader:
        accuracy = model.get_model_accuracy_multiclass(features=e["features"],
                                            true_classes=e["generator"],
                                            device="cuda:"+str(device))
print(accuracy)

### SVM

In [None]:
clf.fit(X_train,gen_train).score(X_test,gen_test)

# Saving the model

In [None]:
# torch.save(model.state_dict(),"./checkpoints/multiclass_1000epochs_0.08loss.pt")

# Loading the model

In [None]:
model2 = MultiClassClassifier()
model2.load_state_dict(torch.load("./checkpoints/multiclass_1000epochs_0.08loss.pt"))

In [None]:
model2.eval().cuda(device)
with torch.no_grad():
    for e in test_loader:
        pred = torch.argmax(model2(e["features"].cuda(device)),dim=1)
        acc = torch.mean(torch.eq(e["generator"].cuda(device),pred).float()).item()
        print(acc)

# Train on data3 test on SB

In [None]:
from utils import load_synthbuster_balanced
X_sb, y_sb = load_synthbuster_balanced("../../data/synthbuster_test",
                                       binary_classification=True,
                                       balance_real_fake=True)

In [None]:
model.get_model_accuracy_binary(torch.Tensor(X_sb).cuda(device),
                                torch.Tensor(y_sb).cuda(device),
                                "cuda:"+str(device))

# Test on OOD

In [6]:
ood = OOD("../../data/ood",load_preprocessed=False,device=device)

Processing real images: 100%|██████████| 100/100 [00:13<00:00,  7.66it/s]
Processing images from Lexica: 100%|██████████| 100/100 [00:13<00:00,  7.53it/s]
Processing images from Ideogram: 100%|██████████| 100/100 [00:03<00:00, 29.78it/s]
Processing images from Leonardo: 100%|██████████| 100/100 [00:06<00:00, 14.56it/s]
Processing images from Copilot: 100%|██████████| 100/100 [00:03<00:00, 30.43it/s]
Processing images from img2img_SD1.5: 100%|██████████| 100/100 [00:06<00:00, 14.71it/s]
Processing images from Photoshop_generativemagnification: 100%|██████████| 100/100 [00:33<00:00,  2.96it/s]
Processing images from Photoshop_generativefill: 100%|██████████| 100/100 [01:44<00:00,  1.05s/it]


In [39]:
loader_test = DataLoader(ood, batch_size=len(ood), shuffle=True)

In [40]:
for e in loader_test:
    model.eval()
    with torch.no_grad():
        for e in loader_test:
            accuracy = model.get_model_accuracy_binary(e["features"],e["label"],"cuda:"+str(device))

print(accuracy)

0.5824999809265137
