In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
from matplotlib import pyplot as plt
from pathlib import Path
import numpy as np
import sys
sys.path.append("../model/")
sys.path.append("../tools/")
from MLP_classifier import MultiClassClassifier
from dataset import TaskA, SimpleDataset
from constants import SEED
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm
  deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)
Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 11.02it/s]


In [11]:
device = "cuda:0"
clf = MultiClassClassifier(n_classes=2).to(device)
clf.load_state_dict(torch.load("../model/checkpoints/binary_train_real_fake_2k_fine_tune_meta_test.pt"))
clf.eval()

data = TaskA(load_from_disk=True,path="/data4/saland/data/taskA.pt")
data_csv = pd.read_csv("../../docs/scan.csv").iloc[:3333]

In [12]:
data_csv.head(15)

Unnamed: 0.1,Unnamed: 0,image_name,class
0,0,A_005fbfn6.png,1
1,1,A_0060ug9j.png,1
2,2,A_00el4hwr.png,0
3,3,A_00h6ucsm.png,1
4,4,A_00hzao5p.png,0
5,5,A_00jghend.png,1
6,6,A_00kj5fih.png,1
7,7,A_00ww8c88.png,1
8,8,A_0102obzc.png,1
9,9,A_012n784o.png,0


In [13]:
pred_true = {name: {"predicted": None, "true": None} for name in data.image_name}
correctness = []
sorted_names = sorted(data.image_name)

for name in sorted_names[:3333]:
    true_label = data_csv[data_csv["image_name"] == name]["class"].item()
    with torch.no_grad():
        predicted_label = 1 - torch.argmax(clf(data.features[name].to(device))).item() # for the model 0 is fake and 1 real but this is reversed for scan.csv data
    correctness.append(1 if predicted_label == true_label else 0)

print("accuracy:",sum(correctness)/len(correctness))

accuracy: 0.8898889888988899


In [25]:
n_train = 500
train_names = sorted(data.image_name)[:n_train]
test_names  = sorted(data.image_name)[n_train:3333]

# idx = sorted(enumerate(data.image_name),key=lambda x : x[1])

train_features = []
train_labels = []
for name in tqdm(train_names):
    train_features.append(data.features[name])
    train_labels.append(1 - data_csv[data_csv["image_name"] == name]["class"].item()) # LABELS INT VALUES ARE INVERSED BETWEEN MODEL AND CSV FILE

train_features = torch.cat(train_features,dim=0)
train_labels   = torch.Tensor(train_labels).type(torch.LongTensor)

test_features = []
test_labels   = []
for name in tqdm(test_names):
    test_features.append(data.features[name])
    test_labels.append(1 - data_csv[data_csv["image_name"] == name]["class"].item())

test_features = torch.cat(test_features,dim=0)
test_labels   = torch.Tensor(test_labels).type(torch.LongTensor)

train_data = SimpleDataset(features=train_features,label=train_labels)
test_data  = SimpleDataset(features=test_features,label=test_labels)

  0%|          | 0/500 [00:00<?, ?it/s]

100%|██████████| 500/500 [00:00<00:00, 4619.47it/s]
100%|██████████| 2833/2833 [00:00<00:00, 4903.00it/s]


## Fine tuning on 100 elements from task A

In [26]:
model_ft = MultiClassClassifier(n_classes=2).to(device)
model_ft.load_state_dict(torch.load("../model/checkpoints/binary_train_real_fake_2k_fine_tune_meta_test.pt"))
model_ft.train()

lr = 1e-3
batch_size = 64
n_epochs = 200

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model_ft.parameters(), lr=lr)
rng = torch.Generator().manual_seed(SEED)

train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True,generator=rng)

loss_history = []
for epoch in range(1,n_epochs+1):
    for idx, batch in enumerate(train_loader):
        # prediction and loss
        pred = model_ft((batch["features"]).to(device))
        loss = loss_fn(pred,batch["label"].type(torch.LongTensor).to(device))

        # backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
    loss, current = loss.item(), idx*batch_size + len(batch["features"])
    if epoch%10 == 0 and epoch > 0:
        loss_history.append(loss)
        print(f"loss: {loss:>7f}  [{epoch:>5d}/{n_epochs:>5d}]")

loss: 0.285333  [   10/  200]
loss: 0.239186  [   20/  200]
loss: 0.278224  [   30/  200]
loss: 0.059232  [   40/  200]
loss: 0.145412  [   50/  200]
loss: 0.182333  [   60/  200]
loss: 0.089602  [   70/  200]
loss: 0.013939  [   80/  200]
loss: 0.070951  [   90/  200]
loss: 0.075738  [  100/  200]
loss: 0.092625  [  110/  200]
loss: 0.173172  [  120/  200]
loss: 0.123947  [  130/  200]
loss: 0.167462  [  140/  200]
loss: 0.011299  [  150/  200]
loss: 0.119154  [  160/  200]
loss: 0.047738  [  170/  200]
loss: 0.021478  [  180/  200]
loss: 0.016776  [  190/  200]
loss: 0.026971  [  200/  200]


## Comparing accuracy between original classifier vs fine-tuned

In [27]:
acc_clf = []
acc_model_ft = []

clf.eval()
model_ft.eval()

test_loader = DataLoader(test_data,batch_size=len(test_data))

with torch.no_grad():
    for e in test_loader:
        acc_clf = clf.get_model_accuracy_binary(e["features"],e["label"],device,binary_model=True)
        acc_model_ft = model_ft.get_model_accuracy_binary(e["features"],e["label"],device,binary_model=True)
        
print("accuracy before fine-tuning:",acc_clf)
print("accuracy after  fine-tuning:",acc_model_ft)

accuracy before fine-tuning: 0.8891634345054626
accuracy after  fine-tuning: 0.9488174915313721
