In [2]:
import os
import shutil

import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms, models
from tqdm import tqdm
 
# PATHS
# Source image folder, destination folder for the labelled copies, and the
# checkpoint file of the target classifier.
NPDD_DIR = "/home/student/Downloads/copycat_fer_project/npdd_images"
OUTPUT_DIR = "/home/student/Downloads/copycat_fer_project/npdd_sl"
TARGET_MODEL_PATH = "/home/student/Downloads/copycat_fer_project/target_mobilenetv2_3class.pth"

# CONFIG
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Minimum softmax probability a prediction needs before its image is kept.
CONF_THRESHOLD = 0.80

# Position i is the label used for model output index i.
CLASS_NAMES = ["happy", "neutral", "sad"]

print(f"Using device: {DEVICE}")

# LOAD TARGET MODEL
# Start from an untrained MobileNetV2; every weight comes from the checkpoint.
model = models.mobilenet_v2(weights=None)

# Swap the stock classifier for dropout followed by a 3-way linear layer.
# The Sequential positions (0 = dropout, 1 = linear) decide the parameter
# names, so this layout presumably mirrors the one used when the checkpoint
# was saved -- load_state_dict below fails on a mismatch.
classifier_head = [
    nn.Dropout(p=0.3),
    nn.Linear(model.last_channel, 3),
]
model.classifier = nn.Sequential(*classifier_head)

# NOTE(review): torch.load un-pickles the file; only load checkpoints from a
# trusted source.
target_state = torch.load(TARGET_MODEL_PATH, map_location=DEVICE)
model.load_state_dict(target_state)

model = model.to(DEVICE)
# Inference mode: disables dropout and uses the stored batch-norm statistics.
model.eval()

# Turns the logits into per-class probabilities along the class dimension.
softmax = nn.Softmax(dim=1)

# TRANSFORM
# Per-channel mean / std used for normalisation (these are the values commonly
# quoted as the ImageNet statistics).
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Applied in order: grey-scale replicated to 3 channels, resize to 224x224,
# conversion to a tensor, then per-channel normalisation.
preprocessing_steps = [
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
]
transform = transforms.Compose(preprocessing_steps)

# CREATE OUTPUT FOLDERS
# One sub-folder per class label under OUTPUT_DIR; folders that already exist
# are left as they are.
class_dirs = [os.path.join(OUTPUT_DIR, label) for label in CLASS_NAMES]
for folder in [OUTPUT_DIR, *class_dirs]:
    os.makedirs(folder, exist_ok=True)

# PROCESS IMAGES
# Classify every image in NPDD_DIR with the target model and copy the ones
# predicted with confidence >= CONF_THRESHOLD into OUTPUT_DIR/<predicted class>/.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

saved_count = 0
total_count = 0
skipped_count = 0  # files that matched an image extension but could not be read

# Filter first so the progress bar counts images only, and sort because
# os.listdir returns entries in arbitrary order.
image_names = sorted(
    name for name in os.listdir(NPDD_DIR)
    if name.lower().endswith(IMAGE_EXTENSIONS)
)

for img_name in tqdm(image_names):
    img_path = os.path.join(NPDD_DIR, img_name)
    total_count += 1

    # Open inside a context manager so the file handle is released on every
    # iteration. convert() forces the full decode, so truncated or corrupt
    # files fail here; skip them instead of aborting the whole run.
    try:
        with Image.open(img_path) as img:
            rgb_image = img.convert("RGB")
    except OSError as err:
        skipped_count += 1
        tqdm.write(f"Skipping unreadable image {img_path}: {err}")
        continue

    # Add a leading batch dimension of 1 and move the tensor to the model's device.
    batch = transform(rgb_image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        probs = softmax(model(batch))
        max_prob, pred = torch.max(probs, 1)

    confidence = max_prob.item()
    predicted_class = CLASS_NAMES[pred.item()]

    # SAVE ONLY IF CONFIDENCE >= CONF_THRESHOLD
    if confidence >= CONF_THRESHOLD:
        save_path = os.path.join(OUTPUT_DIR, predicted_class, img_name)
        # Byte-for-byte copy: re-saving through PIL would re-encode the image
        # (lossy for JPEG) and can fail on mode/format mismatches.
        shutil.copyfile(img_path, save_path)
        saved_count += 1

threshold_label = f"{CONF_THRESHOLD:.0%}"
print("\nTotal NPDD Images:", total_count)
print(f"Images Saved (Confidence >= {threshold_label}):", saved_count)
if skipped_count:
    print("Unreadable Images Skipped:", skipped_count)
print("NPDD-SL dataset created successfully.")


Using device: cuda


100%|████████████████████████████████| 726671/726671 [1:07:23<00:00, 179.71it/s]


Total NPDD Images: 726671
Images Saved (Confidence >= 80%): 35618
NPDD-SL dataset created successfully.



