In [39]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision import datasets, transforms
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from torchvision.datasets import ImageFolder

In [40]:
# Root directory that is scanned recursively for training images.
# Written as a raw string: the path contains "\d", "\w" and "\i", which are
# invalid escape sequences in a normal string literal (a warning today, an
# error in future Python versions). The resulting value is unchanged.
images_dir = r"..\data\wikiart\wikiart\wikiart-saved\images"

In [41]:
class CustomDataset(Dataset):
    """Dataset of every image file found recursively under ``root_dir``.

    Each item is an ``(image, label)`` pair. No class information is read
    from disk: every label is the placeholder value ``0``.

    Args:
        root_dir: Directory that is walked recursively for image files.
        transform: Optional callable applied to each loaded PIL image before
            it is returned.

    Attributes:
        file_list: Sorted list of the image paths that were found.
        labels: One ``0`` per entry of ``file_list``.
    """

    # Lower-case file extensions that are treated as images.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Extensions are compared case-insensitively so files such as
        # "painting.JPG" are not silently skipped. The result is sorted so the
        # index -> file mapping does not depend on the order in which
        # os.walk happens to yield directory entries.
        self.file_list = sorted(
            os.path.join(subdir, file)
            for subdir, _dirs, files in os.walk(self.root_dir)
            for file in files
            if file.lower().endswith(self.IMAGE_EXTENSIONS)
        )

        # Placeholder labels: a list of zeros, one per file.
        self.labels = [0] * len(self.file_list)

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and return ``(image, label)``.

        The image is decoded as RGB; ``transform`` (if given) is applied to
        it afterwards. ``label`` is the entry of ``self.labels`` at ``idx``.
        """
        image_path = self.file_list[idx]
        # Image.open is lazy and keeps the file open. Decode inside a `with`
        # block so the handle is always released, and convert to RGB so every
        # item has the same mode whatever the file's native mode is.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label

In [42]:
# Preprocessing applied to every image, listed in the order it runs.
_preprocessing_steps = [
    # Resize (an int size scales the shorter edge to 256), then take the
    # central 224x224 crop.
    transforms.Resize(256),
    transforms.CenterCrop(224),
    # Drop colour information but keep three channels in the output.
    transforms.Grayscale(num_output_channels=3),
    # PIL image -> float tensor.
    transforms.ToTensor(),
    # Per-channel normalisation (these values match the widely used
    # ImageNet channel statistics).
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocessing_steps)


In [43]:
# Dataset over every image found under images_dir; the preprocessing
# pipeline is applied to each item as it is loaded.
wikiartSet = CustomDataset(root_dir=images_dir, transform=transform)

# Serve the dataset in shuffled mini-batches of 32 items.
dataloader = DataLoader(dataset=wikiartSet, batch_size=32, shuffle=True)

# Number of passes over the dataset.
epochs = 5


In [44]:
# Build an untrained VGG19. Calling the constructor without arguments gives
# randomly initialised weights, exactly like `pretrained=False`, but avoids
# the deprecation warning newer torchvision releases emit for `pretrained`.
model = models.vgg19()
# Be explicit that the dropout layers should be active while training.
model.train()

# Define the loss function
criterion = nn.CrossEntropyLoss()

# Define the optimizer.
# It is bound to `optimizer`, not `optim`: rebinding `optim` would shadow the
# imported `torch.optim` module, so running this cell a second time in the
# same kernel would fail on `optim.SGD`.
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Define the number of epochs
epochs = 5

# Train the model using the DataLoader instance.
# No device placement is done here; the model and the batches stay wherever
# they were created.
# NOTE(review): the loss is only meaningful if `labels` carries real class
# ids -- confirm the dataset supplies them.
for epoch in range(epochs):
    running_loss = 0.0
    for images, labels in dataloader:
        # Clear gradients left over from the previous step before computing
        # new ones.
        optimizer.zero_grad()
        # Forward pass
        outputs = model(images)
        # Compute the loss and perform backpropagation
        loss = criterion(outputs, labels)
        loss.backward()
        # Update the model parameters
        optimizer.step()
        running_loss += loss.item()
    # Report the mean batch loss so progress is visible in the notebook.
    mean_loss = running_loss / max(len(dataloader), 1)
    print(f"epoch {epoch + 1}/{epochs} - mean loss: {mean_loss:.4f}")

# Save the trained model
torch.save(model.state_dict(), "trained_vgg19.pth")
