In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import os

from train_autoencoder import Encoder


In [None]:
# Preprocessing applied to every image: resize to a fixed 128x128 resolution,
# then convert it to a tensor.
transform = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])

# Dataset layout, relative to the kernel's working directory:
#   split_dataset/train, split_dataset/test_realworld, split_dataset/test_studio
base_dir = os.getcwd()
data_dir = os.path.join(base_dir, 'split_dataset')
train_dir, test_realworld_dir, test_studio_dir = (
    os.path.join(data_dir, split_name)
    for split_name in ('train', 'test_realworld', 'test_studio')
)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Re-create the encoder with the same architecture the saved weights were trained with,
# so the parameter names/shapes stored in encoder.pth match this module.
encoder = Encoder(in_channels=3)

# weights_only=True restricts unpickling to tensors and plain containers, which is all a
# state_dict needs, and avoids executing arbitrary pickled code from the checkpoint file.
# map_location remaps tensors saved on a different device onto the chosen one.
encoder.load_state_dict(torch.load("encoder.pth", map_location=device, weights_only=True))
encoder.to(device)
encoder.eval()  # switch to inference behaviour for any dropout / batch-norm layers

In [None]:
# Feature extraction: run the frozen encoder over each split and keep the flattened latents.
# The loaders are built explicitly here; previously `train_loader` / `test_loader` were used
# without being defined anywhere, so the notebook failed on Restart & Run All.
BATCH_SIZE = 64  # inference-only batch size; tune to fit device memory


def extract_features(model, loader, device):
    """Encode every batch from ``loader`` with ``model`` and collect flattened latents.

    Args:
        model: encoder module, called as ``model(images)``; should already be in eval mode.
        loader: iterable yielding ``(images, labels)`` batches of tensors.
        device: device the images are moved to before encoding.

    Returns:
        ``(features, labels)`` as NumPy arrays. ``features`` has one row per sample: the
        encoder output with every non-batch dimension flattened.

    Raises:
        ValueError: if ``loader`` yields no batches (``torch.cat`` would otherwise fail
            with a less helpful error).
    """
    feature_batches = []
    label_batches = []
    with torch.no_grad():
        for images, batch_labels in loader:
            latent = model(images.to(device))
            # reshape (unlike view) also works if the encoder returns a non-contiguous tensor
            feature_batches.append(latent.reshape(latent.size(0), -1).cpu())
            label_batches.append(batch_labels)

    if not feature_batches:
        raise ValueError("loader yielded no batches; nothing to extract")

    features = torch.cat(feature_batches, dim=0).numpy()
    labels = torch.cat(label_batches, dim=0).numpy()
    return features, labels


# Assumes an ImageFolder layout (one sub-directory per class) under each split - TODO confirm.
train_dataset = datasets.ImageFolder(train_dir, transform=transform)
# NOTE(review): the original never said which split `test_loader` came from; the real-world
# split is used here - switch to `test_studio_dir` if that was the intended test set.
test_dataset = datasets.ImageFolder(test_realworld_dir, transform=transform)

# ImageFolder assigns label indices from the sorted class folder names, so both splits must
# expose the same classes or the saved train/test labels would not mean the same thing.
assert test_dataset.class_to_idx == train_dataset.class_to_idx, "train/test class folders differ"

# shuffle=False keeps the saved rows in a stable, reproducible order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

train_features, train_labels = extract_features(encoder, train_loader, device)
test_features, test_labels = extract_features(encoder, test_loader, device)

# Save extracted features for downstream use
np.save("train_features.npy", train_features)
np.save("train_labels.npy", train_labels)
np.save("test_features.npy", test_features)
np.save("test_labels.npy", test_labels)

print("Feature extraction completed and saved!")