In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from MLP_classifier import MultiClassClassifier
from torch.utils.data import DataLoader
from dataset import DeepFakeDatasetFastLoad
import torch.nn as nn
import torch
import sys
sys.path.append("../tools")
from constants import PATH_TO_DATA, SEED
from sklearn.model_selection import train_test_split
from torch.utils.data import random_split

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
device = 0
model = MultiClassClassifier().cuda(device=device)

lr = 1e-3
batch_size = 64
epochs = 5

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

data = DeepFakeDatasetFastLoad("../../data/df_34000.pt")

rng = torch.Generator().manual_seed(SEED)
train_data, test_data = random_split(data,[0.8,0.2],generator=rng)

train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True)
test_loader  = DataLoader(test_data,batch_size=len(test_data),shuffle=True)

model.train()

MultiClassClassifier(
  (fc1): Linear(in_features=768, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=18, bias=True)
  (act): ReLU()
)

## Train for multi-class classification

In [5]:
n_epochs = 1000
for epoch in range(1,n_epochs+1):
    for idx, batch in enumerate(train_loader):
        # prediction and loss
        pred = model(batch["features"].cuda(device))
        loss = loss_fn(pred,batch["generator"].cuda(device))

        # backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
    loss, current = loss.item(), idx*batch_size + len(batch["features"])
    if epoch%10 == 0 and epoch > 0:
        print(f"loss: {loss:>7f}  [{epoch:>5d}/{n_epochs:>5d}]")

loss: 1.422957  [   10/ 1000]
loss: 0.921121  [   20/ 1000]
loss: 0.947804  [   30/ 1000]
loss: 0.851022  [   40/ 1000]
loss: 0.766952  [   50/ 1000]
loss: 0.386493  [   60/ 1000]
loss: 0.367334  [   70/ 1000]
loss: 0.390761  [   80/ 1000]
loss: 0.317838  [   90/ 1000]
loss: 0.302915  [  100/ 1000]
loss: 0.283196  [  110/ 1000]
loss: 0.301721  [  120/ 1000]
loss: 0.335943  [  130/ 1000]
loss: 0.376663  [  140/ 1000]
loss: 0.344195  [  150/ 1000]
loss: 0.231037  [  160/ 1000]
loss: 0.250204  [  170/ 1000]
loss: 0.272052  [  180/ 1000]
loss: 0.356275  [  190/ 1000]
loss: 0.239789  [  200/ 1000]
loss: 0.289642  [  210/ 1000]
loss: 0.236137  [  220/ 1000]
loss: 0.394808  [  230/ 1000]
loss: 0.319827  [  240/ 1000]
loss: 0.213794  [  250/ 1000]
loss: 0.135559  [  260/ 1000]
loss: 0.190557  [  270/ 1000]
loss: 0.222381  [  280/ 1000]
loss: 0.221478  [  290/ 1000]
loss: 0.280228  [  300/ 1000]
loss: 0.212859  [  310/ 1000]
loss: 0.200182  [  320/ 1000]
loss: 0.151896  [  330/ 1000]
loss: 0.32

## Test for binary classification

In [6]:
import torch.types

with torch.no_grad():
    model.eval()
    for e in test_loader:
        pred = data.class_to_label(torch.argmax(model(e["features"].cuda(device)),dim=1))
        accuracy = torch.mean(torch.eq(pred,e["label"]).float()).item()

print(accuracy)

0.9739705920219421


### Comparison with SVM

In [7]:
from sklearn.svm import LinearSVC
from sklearn.multiclass import OneVsOneClassifier

clf = OneVsOneClassifier(LinearSVC(dual="auto"))

In [8]:
train_loader_all = DataLoader(train_data,batch_size=len(train_data))
for e in train_loader_all:
    X_train = e["features"]
    gen_train = e["generator"]
    label_train = e["label"]
for e in test_loader:
    X_test = e["features"]
    gen_test = e["generator"]
    label_test = e["label"]

clf.fit(X_train, gen_train) # train on multi-class classification

In [9]:
import numpy as np
pred = data.class_to_label(clf.predict(X_test))
np.mean(label_test.numpy() == pred.numpy()) # binary classification performance

0.97

## Test on multi-class classification

### Neural Network

In [10]:
import torch.types

with torch.no_grad():
    model.eval()
    for e in test_loader:
        pred = torch.argmax(model(e["features"].cuda(device)),dim=1)
        accuracy = torch.mean(torch.eq(pred.cpu(),e["generator"]).float()).item()

print(accuracy)

0.8783823251724243


### SVM

In [11]:
clf.fit(X_train,gen_train).score(X_test,gen_test)

0.8730882352941176

# Saving the model

In [16]:
# torch.save(model.state_dict(),"./checkpoints/multiclass_1000epochs_0.08loss.pt")

# Loading the model

In [17]:
model2 = MultiClassClassifier()
model2.load_state_dict(torch.load("./checkpoints/multiclass_1000epochs_0.08loss.pt"))

<All keys matched successfully>

In [30]:
model2.eval().cuda(device)
with torch.no_grad():
    for e in test_loader:
        pred = torch.argmax(model2(e["features"].cuda(device)),dim=1)
        acc = torch.mean(torch.eq(e["generator"].cuda(device),pred).float()).item()
        print(acc)

0.8783823251724243
