In [4]:
import os
import pickle
import sys

import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

# Notebooks have no `__file__`, so all paths are resolved relative to the
# kernel's working directory (assumed to be this notebook's folder — confirm
# how the kernel is launched). The previous `os.path.realpath('__file__')`
# passed the *string* '__file__', which only resolved to the cwd by accident.
script_dir = os.path.realpath(os.getcwd())
parent_dir = os.path.dirname(script_dir)

# Make the project-local modules in the parent directory importable.
# The guard keeps re-runs of this cell from stacking duplicate sys.path entries.
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

import functions  # noqa: E402  project-local; requires the sys.path tweak above
import models  # noqa: E402  project-local; requires the sys.path tweak above

In [5]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Map from class index to human-readable label
labels_map = {
    0: "REAL",
    1: "FAKE"
}

# Transform applied to every image.
# Convert to PIL *before* resizing: Resize on a PIL image always antialiases,
# while on a tensor input older torchvision versions resize without
# antialiasing, leaving artifacts that depend on each image's original
# resolution (a possible shortcut for a real-vs-fake classifier).
# NOTE(review): ToPILImage requires BinaryGenImage to yield a tensor/ndarray,
# which the original pipeline already relied on — confirm in functions.py.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Standard ImageNet per-channel mean / std
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

dataset_dir = os.path.join(script_dir, '../../dataset/stable_diffusion_1_4')

# 324k samples for training
train_ds = functions.BinaryGenImage(img_dir=os.path.join(dataset_dir, 'train'), transform=transform)

# 6k samples for validation and 6k samples for testing, carved out of the
# dataset's 'val' folder with a fixed, class-stratified 50/50 split
val_test_ds = functions.BinaryGenImage(img_dir=os.path.join(dataset_dir, 'val'), transform=transform)
idx_val, idx_test = train_test_split(
    range(len(val_test_ds)),
    test_size=0.5,  # 50-50 split
    stratify=val_test_ds.img_labels,  # keep the class balance equal in both halves
    random_state=0,
)
validation_ds = Subset(val_test_ds, idx_val)
test_ds = Subset(val_test_ds, idx_test)

# The two halves must partition the original set exactly
assert len(validation_ds) + len(test_ds) == len(val_test_ds)

cuda


# Untrained model

In [6]:
model_init_function = models.binary_untrained
epochs = 15
batch_size = 64
lr = 0.01
val_batch_size = 64

# Seed torch's global RNG so weight initialisation and the DataLoader shuffle
# order are reproducible across runs (cuDNN kernels may still be nondeterministic).
torch.manual_seed(0)

val_dl = DataLoader(validation_ds, batch_size=val_batch_size, num_workers=4)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

# Output folders live next to the notebook, like the dataset paths. They are
# created *before* training so a missing folder cannot make the save step
# fail after a long run.
weights_dir = os.path.join(script_dir, 'weights')
stats_dir = os.path.join(script_dir, 'stats')
os.makedirs(weights_dir, exist_ok=True)
os.makedirs(stats_dir, exist_ok=True)

# Initialize the model
model = model_init_function().to(device)
model_name = f"{model_init_function.__name__}_batch{batch_size}"
print("\n\n - " + model_name)

# Train the model
train_err, val_err, train_loss = functions.train_network(model, device, lr, epochs, train_dl, val_dl)
functions.plot_training_stats(train_err, val_err, train_loss)
# Tag the name with len(train_err); presumably one entry per completed epoch —
# confirm in functions.train_network.
model_name += f"_epochs{len(train_err)}"

# Save the progress
torch.save(model.state_dict(), os.path.join(weights_dir, model_name + '.pth'))
stats_dict = {
    'train_err': train_err,
    'val_err': val_err,
    'train_loss': train_loss
}
with open(os.path.join(stats_dir, model_name + '.pkl'), 'wb') as f:
    pickle.dump(stats_dict, f)



 - binary_untrained_batch64
22.573685884475708
0.3340001106262207
0.3230009078979492
0.3275582790374756
0.3285694122314453
0.320995569229126
0.32700061798095703
0.3280022144317627
0.32399821281433105
0.32399821281433105
0.3284480571746826
0.33051443099975586
0.32700085639953613
0.32899951934814453
0.32901525497436523


KeyboardInterrupt: 

# Fine-tuning the pretrained model

In [None]:
model_init_function = models.binary_finetuning
epochs = 15
batch_size = 64
# NOTE(review): 0.01 is a high learning rate for fine-tuning pretrained
# weights; consider a smaller value if validation error plateaus.
lr = 0.01
val_batch_size = 64

# Seed torch's global RNG so weight initialisation and the DataLoader shuffle
# order are reproducible across runs (cuDNN kernels may still be nondeterministic).
torch.manual_seed(0)

val_dl = DataLoader(validation_ds, batch_size=val_batch_size, num_workers=4)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

# Output folders live next to the notebook, like the dataset paths. They are
# created *before* training so a missing folder cannot make the save step
# fail after a long run.
weights_dir = os.path.join(script_dir, 'weights')
stats_dir = os.path.join(script_dir, 'stats')
os.makedirs(weights_dir, exist_ok=True)
os.makedirs(stats_dir, exist_ok=True)

# Initialize the model
model = model_init_function().to(device)
model_name = f"{model_init_function.__name__}_batch{batch_size}"
print("\n\n - " + model_name)

# Train the model
train_err, val_err, train_loss = functions.train_network(model, device, lr, epochs, train_dl, val_dl)
functions.plot_training_stats(train_err, val_err, train_loss)
# Tag the name with len(train_err); presumably one entry per completed epoch —
# confirm in functions.train_network.
model_name += f"_epochs{len(train_err)}"

# Save the progress
torch.save(model.state_dict(), os.path.join(weights_dir, model_name + '.pth'))
stats_dict = {
    'train_err': train_err,
    'val_err': val_err,
    'train_loss': train_loss
}
with open(os.path.join(stats_dir, model_name + '.pkl'), 'wb') as f:
    pickle.dump(stats_dict, f)