In [15]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from sklearn.metrics import precision_score, recall_score, accuracy_score
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

In [16]:
class WikiArtDataset(Dataset):
    """Image-folder dataset: each immediate subdirectory of ``root_dir`` is one class.

    Class indices are assigned in sorted (alphabetical) order of the subdirectory
    names so the label mapping is stable across runs and machines. Raw
    ``os.listdir`` order is arbitrary and filesystem-dependent.

    Args:
        root_dir: Directory whose subdirectories each hold the files of one class.
        transform: Optional callable applied to each loaded PIL image.

    Attributes:
        classes: Sorted list of class (subdirectory) names; list index == label.
        class_to_idx: Mapping from class name to integer label.
        files: List of ``(file_path, label)`` pairs, one per regular file found.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Sort so that label i always refers to the same class (listdir order is arbitrary).
        self.classes = sorted(
            d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))
        )
        self.class_to_idx = {cls: label for label, cls in enumerate(self.classes)}
        self.files = []

        for label, cls in enumerate(self.classes):
            class_dir = os.path.join(root_dir, cls)
            # Sorted for a deterministic sample order, which a seeded random_split relies on.
            for file_name in sorted(os.listdir(class_dir)):
                file_path = os.path.join(class_dir, file_name)
                # NOTE(review): every regular file is assumed to be an image; stray
                # files (e.g. desktop.ini, Thumbs.db) will only fail later in __getitem__.
                if os.path.isfile(file_path):
                    self.files.append((file_path, label))

    def __len__(self):
        """Return the number of samples (files) in the dataset."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx`` as an RGB image and return ``(image, label)``.

        The image is passed through ``self.transform`` when one was given.

        Raises:
            RuntimeError: If the file cannot be opened or converted to RGB.
        """
        file_path, label = self.files[idx]
        try:
            # The context manager releases the file handle. convert() returns a new,
            # fully loaded image, so it stays valid after the file is closed.
            with Image.open(file_path) as img:
                image = img.convert("RGB")
        except Exception as e:
            raise RuntimeError(f"Error loading image {file_path}: {e}") from e
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [17]:
class ArtClassifier(nn.Module):
    """Fully-connected classifier: three ReLU hidden layers followed by a linear head.

    Layer widths are ``input_dim -> 1024 -> 512 -> 256 -> output_dim``.
    Inputs must already be flattened so the last dimension is ``input_dim``.
    The head returns raw logits (no softmax), which is what ``nn.CrossEntropyLoss``
    expects.

    Args:
        input_dim: Number of input features per sample.
        output_dim: Number of classes (length of the logit vector).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Registration order (hidden1, hidden2, hidden3, linear) fixes the
        # state_dict keys and the order in which parameters are initialised.
        self.hidden1 = nn.Linear(input_dim, 1024)
        self.hidden2 = nn.Linear(1024, 512)
        self.hidden3 = nn.Linear(512, 256)
        self.linear = nn.Linear(256, output_dim)

    def forward(self, x):
        """Map flattened inputs to unnormalised class logits."""
        for hidden in (self.hidden1, self.hidden2, self.hidden3):
            x = torch.relu(hidden(x))
        return self.linear(x)

In [18]:
def _train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Images are flattened to ``(batch, -1)`` because ``ArtClassifier`` is a
    fully-connected network. Returns ``float("nan")`` when ``loader`` yields no
    batches, rather than failing on an undefined running average.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    pbar = tqdm(loader, total=len(loader), desc=desc)
    for images, labels in pbar:
        images = images.view(images.size(0), -1).to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        pbar.set_postfix({'loss': f'{running_loss / num_batches:.4f}'})
    return running_loss / num_batches if num_batches else float("nan")


def _evaluate(model, loader, device):
    """Return ``(true_labels, predicted_labels)`` lists for every sample in ``loader``."""
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for images, labels in loader:
            images = images.view(images.size(0), -1).to(device)
            predicted = model(images).argmax(dim=1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())
    return all_labels, all_preds


if __name__ == "__main__":
    # ---- Configuration ----------------------------------------------------
    # Set WIKIART_DATA_PATH to point at another dataset location; the original
    # local path is kept as the fallback.
    data_path = os.environ.get(
        "WIKIART_DATA_PATH", r"C:\Users\yozev\OneDrive\Desktop\artFiltered"
    )
    IMAGE_SIZE = 64       # images are resized to IMAGE_SIZE x IMAGE_SIZE
    BATCH_SIZE = 64
    LEARNING_RATE = 0.001
    epochs = 20
    SEED = 42             # fixes the train/test split and weight initialisation

    torch.manual_seed(SEED)

    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor()
    ])

    # ---- Load dataset -----------------------------------------------------
    dataset = WikiArtDataset(root_dir=data_path, transform=transform)
    if len(dataset) == 0:
        raise RuntimeError(f"No image files found under {data_path!r}")
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    # A seeded generator keeps the split fixed between runs. Without it, the
    # test set changes every run and the metrics are not comparable.
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, test_size],
        generator=torch.Generator().manual_seed(SEED),
    )

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # ---- Model setup ------------------------------------------------------
    # Flattened RGB image: the dataset converts to RGB (3 channels) and the
    # transform resizes to IMAGE_SIZE x IMAGE_SIZE.
    input_dim = IMAGE_SIZE * IMAGE_SIZE * 3
    output_dim = len(dataset.classes)  # Number of art styles
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Running on: {device}\n")

    model = ArtClassifier(input_dim, output_dim).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # ---- Training loop ----------------------------------------------------
    train_losses = []
    for epoch in range(epochs):
        avg_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, device,
            desc=f'Epoch [{epoch+1}/{epochs}]',
        )
        train_losses.append(avg_loss)
        print(f'Epoch [{epoch+1}/{epochs}] Loss: {avg_loss:.4f}')

    # ---- Plot training loss -----------------------------------------------
    # The x-axis starts at 1 to match the epoch numbers printed above.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
    ax.set_title('Training Loss Over Time')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True)
    fig.savefig('training_loss.png')
    plt.close(fig)

    # ---- Evaluation -------------------------------------------------------
    all_labels, all_preds = _evaluate(model, test_loader, device)

    # Weighted averages. zero_division=0 states explicitly how classes that are
    # never predicted are scored (sklearn's default also gives 0 but warns).
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)

    print('\nFinal Results:')
    print(f'Accuracy: {accuracy:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')

Running on: cuda



Epoch [1/20]: 100%|██████████| 810/810 [03:23<00:00,  3.98it/s, loss=2.4596]


Epoch [1/20] Loss: 2.4596


Epoch [2/20]: 100%|██████████| 810/810 [02:52<00:00,  4.70it/s, loss=2.3719]


Epoch [2/20] Loss: 2.3719


Epoch [3/20]: 100%|██████████| 810/810 [02:51<00:00,  4.73it/s, loss=2.3248]


Epoch [3/20] Loss: 2.3248


Epoch [4/20]: 100%|██████████| 810/810 [02:51<00:00,  4.72it/s, loss=2.2912]


Epoch [4/20] Loss: 2.2912


Epoch [5/20]: 100%|██████████| 810/810 [02:50<00:00,  4.74it/s, loss=2.2675]


Epoch [5/20] Loss: 2.2675


Epoch [6/20]: 100%|██████████| 810/810 [02:50<00:00,  4.74it/s, loss=2.2501]


Epoch [6/20] Loss: 2.2501


Epoch [7/20]: 100%|██████████| 810/810 [02:52<00:00,  4.70it/s, loss=2.2288]


Epoch [7/20] Loss: 2.2288


Epoch [8/20]: 100%|██████████| 810/810 [02:53<00:00,  4.67it/s, loss=2.2111]


Epoch [8/20] Loss: 2.2111


Epoch [9/20]: 100%|██████████| 810/810 [02:52<00:00,  4.69it/s, loss=2.1939]


Epoch [9/20] Loss: 2.1939


Epoch [10/20]: 100%|██████████| 810/810 [03:54<00:00,  3.45it/s, loss=2.1793]


Epoch [10/20] Loss: 2.1793


Epoch [11/20]: 100%|██████████| 810/810 [03:58<00:00,  3.40it/s, loss=2.1680]


Epoch [11/20] Loss: 2.1680


Epoch [12/20]: 100%|██████████| 810/810 [04:26<00:00,  3.03it/s, loss=2.1532]


Epoch [12/20] Loss: 2.1532


Epoch [13/20]: 100%|██████████| 810/810 [04:26<00:00,  3.04it/s, loss=2.1402]


Epoch [13/20] Loss: 2.1402


Epoch [14/20]: 100%|██████████| 810/810 [04:16<00:00,  3.15it/s, loss=2.1236]


Epoch [14/20] Loss: 2.1236


Epoch [15/20]: 100%|██████████| 810/810 [03:01<00:00,  4.46it/s, loss=2.1145]


Epoch [15/20] Loss: 2.1145


Epoch [16/20]: 100%|██████████| 810/810 [02:51<00:00,  4.71it/s, loss=2.0992]


Epoch [16/20] Loss: 2.0992


Epoch [17/20]: 100%|██████████| 810/810 [02:51<00:00,  4.73it/s, loss=2.0893]


Epoch [17/20] Loss: 2.0893


Epoch [18/20]: 100%|██████████| 810/810 [02:52<00:00,  4.69it/s, loss=2.0787]


Epoch [18/20] Loss: 2.0787


Epoch [19/20]: 100%|██████████| 810/810 [03:47<00:00,  3.57it/s, loss=2.0657]


Epoch [19/20] Loss: 2.0657


Epoch [20/20]: 100%|██████████| 810/810 [04:28<00:00,  3.02it/s, loss=2.0514]


Epoch [20/20] Loss: 2.0514

Final Results:
Accuracy: 0.2563
Precision: 0.2631
Recall: 0.2563
