In [1]:
import sys
import os
# Jupyter kernels have no `__file__`, so the reference directory is the
# kernel's working directory (typically the folder holding this notebook).
# Ask for it explicitly instead of resolving the string literal '__file__'.
script_dir = os.path.realpath(os.getcwd())
parent_dir = os.path.dirname(script_dir)
# Make the project modules in the parent directory (`functions`, `models`)
# importable. The guard keeps re-runs of this cell from stacking duplicate
# entries at the front of sys.path.
if not sys.path or sys.path[0] != parent_dir:
    sys.path.insert(0, parent_dir)

import functions
import models
import pickle
from sklearn.model_selection import train_test_split
import torch
from torchvision import transforms
from torch.utils.data import Subset, DataLoader

In [2]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Label maps: binary real/fake target, and the multiclass names read from
# imagenet_classes.txt (one name per line; line index == class index)
binary_labels_map = {
    0: "REAL",
    1: "FAKE"
}
# Explicit encoding: the platform default (e.g. cp1252 on Windows) may not match the file
with open(os.path.join(script_dir, '../../dataset/imagenet_classes.txt'), 'r', encoding='utf-8') as file:
    labels = [line.strip() for line in file]
multiclass_labels_map = dict(enumerate(labels))

# Preprocessing applied to every sample: resize to 224x224, scale to [0, 1]
# and normalize with the standard ImageNet channel mean/std.
# NOTE(review): Resize runs before ToPILImage, so the dataset presumably yields
# image tensors (ToPILImage does not accept PIL images); the
# ToPILImage -> ToTensor round-trip is what rescales 8-bit pixels to [0, 1].
# Kept unchanged so preprocessing matches weights already trained with it.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 324k samples for training
train_ds = functions.MulticlassGenImage(img_dir=os.path.join(script_dir, '../../dataset/stable_diffusion_1_4/train'), transform=transform)

# 6k samples for validation and 6k samples for testing, carved out of the official val split
val_test_ds = functions.MulticlassGenImage(img_dir=os.path.join(script_dir, '../../dataset/stable_diffusion_1_4/val'), transform=transform)

# Stratify on the (binary, multiclass) pair so both halves get the same label mix.
# The separator keeps the key unambiguous: without it, distinct pairs such as
# (1, 23) and (12, 3) would both map to "123".
stratify_labels = [f"{label['binary']}_{label['multiclass']}" for label in val_test_ds.img_labels]
idx_val, idx_test = train_test_split(
    range(len(val_test_ds)),
    test_size=0.5,              # 50-50 split
    stratify=stratify_labels,   # uniform class distribution across val/test
    random_state=0,             # fixed so the test set is identical across runs
)
validation_ds = Subset(val_test_ds, idx_val)
test_ds = Subset(val_test_ds, idx_test)
# Sanity check: the two subsets together cover the whole val split
assert len(validation_ds) + len(test_ds) == len(val_test_ds)

cuda


# Finetuning

In [3]:
# Hyperparameters
model_init_function = models.multiclass_finetuning
epochs = 15
batch_size = 64
lr = 1e-5
val_batch_size = 64
seed = 0  # same value as the val/test split's random_state

# Output folders, resolved against the notebook directory like the dataset paths.
# Created up front: a missing folder would otherwise only surface at torch.save,
# i.e. after the whole training run.
weights_dir = os.path.join(script_dir, 'weights')
stats_dir = os.path.join(script_dir, 'stats')
os.makedirs(weights_dir, exist_ok=True)
os.makedirs(stats_dir, exist_ok=True)

# Seed torch's global RNG so the training shuffle order and any freshly
# initialised layers are repeatable (cuDNN kernels may still add run-to-run noise).
torch.manual_seed(seed)

# Page-locked host memory speeds up host->GPU copies; only useful on CUDA
pin_memory = device.type == "cuda"
val_dl = DataLoader(validation_ds, batch_size=val_batch_size, num_workers=4, pin_memory=pin_memory)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=pin_memory)

# Initialize the model
model = model_init_function().to(device)
model_name = f"{model_init_function.__name__}_batch{batch_size}"
print("\n\n - " + model_name)

# Train the model.
# NOTE(review): weights and stats are written to disk only after training returns,
# so an interrupted run (as in the captured output) saves nothing; per-epoch
# checkpointing would have to live in functions.train_network_multioutput.
train_errors, val_errors, train_loss = functions.train_network_multioutput(model, device, lr, epochs, train_dl, val_dl)
functions.plot_training_stats_multioutput(train_errors, val_errors, train_loss)
# Tag the name with the length of the recorded history (presumably the epochs actually run)
model_name += f"_epochs{len(train_errors['binary'])}"

# Save the weights and the training history
torch.save(model.state_dict(), os.path.join(weights_dir, model_name + '.pth'))
stats_dict = {
    'train_errors': train_errors,
    'val_errors': val_errors,
    'train_loss': train_loss,
}
with open(os.path.join(stats_dir, model_name + '.pkl'), 'wb') as f:
    pickle.dump(stats_dict, f)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to C:\Users\uiooo/.cache\torch\hub\checkpoints\resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:11<00:00, 9.19MB/s]




 - multiclass_finetuning_batch64
Epoch 1: current lr = 1e-05
Train error: Combined=95.99%; Binary=32.02%; multiclass=94.20%; 
Validation error: Combined=96.17%; Binary=32.20%; multiclass=94.22%; 
Loss: 6.617e+00


KeyboardInterrupt: 