In [None]:
!pip install torch torchvision timm matplotlib tqdm pillow numpy

In [1]:
import os
import pickle
import random

import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from tqdm.notebook import tqdm

In [2]:
class artStylesDataset(Dataset):
    """Dataset that delegates everything to a torchvision ``ImageFolder``.

    Args:
        data_dir: Root directory handed to ``ImageFolder``.
        transform: Optional transform forwarded to ``ImageFolder``.
    """

    def __init__(self, data_dir, transform=None):
        image_folder = ImageFolder(data_dir, transform=transform)
        # All other methods simply forward to this wrapped dataset.
        self.data = image_folder

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        sample_count = len(self.data)
        return sample_count

    def __getitem__(self, idx):
        """Return whatever the wrapped dataset yields at position ``idx``."""
        sample = self.data[idx]
        return sample

    @property
    def classes(self):
        """Class names as exposed by the wrapped dataset's ``classes`` attribute."""
        wrapped = self.data
        return wrapped.classes

In [3]:
class ArtsClassifier(nn.Module):
    """Image classifier built on a pretrained timm ``efficientnet_b0`` backbone.

    Args:
        num_classes: Number of output logits produced by the final linear layer.
    """

    def __init__(self, num_classes=6):
        super().__init__()
        self.base_model = timm.create_model("efficientnet_b0", pretrained=True)

        # Keep every child module of the backbone except the last one
        # (presumably its original classification head - confirm against timm).
        backbone_children = list(self.base_model.children())
        self.features = nn.Sequential(*backbone_children[:-1])

        # Width of the feature vector fed to the new head.
        feature_dim = 1280
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Run ``x`` through the feature extractor, then the linear head."""
        return self.classifier(self.features(x))

In [4]:
# Pre-processing applied to every image: resize to a fixed size (torchvision's
# Resize takes a (height, width) pair), then convert the PIL image to a tensor.
transform = transforms.Compose([
    transforms.Resize((464, 300)),
    transforms.ToTensor(),
])

num_epoch = 15  # Should work fine between 5 to 20

# Output artefacts are tagged with the epoch count. The output directory is
# created up front so that a missing folder cannot make the final save/plot
# cells fail after the whole training run has already completed.
results_dir = "./Results"
os.makedirs(results_dir, exist_ok=True)
model_path = results_dir + "/artPythonModel-" + str(num_epoch) + "-epochs.pkl"
plot_path = results_dir + "/artPythonPlot-" + str(num_epoch) + "-epochs.png"

trainDataset = artStylesDataset("./Data", transform=transform)
validDataset = artStylesDataset("./Valid", transform=transform)

# Only the training data needs shuffling; validation just accumulates a loss
# over the whole set, so a fixed order keeps evaluation deterministic.
trainLoader = DataLoader(trainDataset, batch_size=16, shuffle=True)
validLoader = DataLoader(validDataset, batch_size=16, shuffle=False)

In [5]:
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = ArtsClassifier(num_classes=6)

# Best-effort resume: if a pickled model from a previous run exists, use it
# instead of the freshly initialised network.
# NOTE(security): pickle.load can execute arbitrary code - only load checkpoint
# files that were produced by this notebook / a trusted source.
try:
    with open(model_path, 'rb') as file:
        model = pickle.load(file)
except FileNotFoundError:
    # No checkpoint yet (first run): keep the freshly created model.
    pass
except Exception as err:
    # Corrupt or incompatible checkpoint: report it instead of hiding it,
    # then continue with the fresh model so the notebook stays runnable.
    print(f"Could not load checkpoint {model_path!r} ({err!r}); using a fresh model")

# Move whichever model ended up selected (fresh or loaded) to the active device,
# so a loaded checkpoint can never be left on a different device than the data.
model.to(device)

In [6]:
# Loss function and optimiser for the training run.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adamax(model.parameters(), lr=0.002)

# Per-epoch loss history, consumed later by the plotting cell.
trainLoss, valLoss = [], []

print("Training session start")
for epoch in range(num_epoch):
    # ---- training pass ----
    model.train()
    loss_sum = 0.0
    for batch_images, batch_labels in tqdm(trainLoader, desc="Training progress"):
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_images), batch_labels)
        batch_loss.backward()
        optimizer.step()
        # Weight each batch loss by its batch size so the final division by the
        # dataset length yields a per-sample average even with a short last batch.
        loss_sum += batch_loss.item() * batch_labels.size(0)
    train_loss = loss_sum / len(trainLoader.dataset)
    trainLoss.append(train_loss)

    # ---- validation pass (no gradient tracking) ----
    model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        for batch_images, batch_labels in tqdm(validLoader, desc="Validation progress"):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            batch_loss = criterion(model(batch_images), batch_labels)
            loss_sum += batch_loss.item() * batch_labels.size(0)
    val_loss = loss_sum / len(validLoader.dataset)
    valLoss.append(val_loss)

    print(f"Epoch {epoch+1}/{num_epoch} - Train loss: {train_loss}, Validation loss: {val_loss}")

Training skipped - model loaded from file


In [7]:
# Plot the per-epoch training and validation loss curves on one labelled axis.
fig, ax = plt.subplots()
ax.plot(trainLoss, label="Training loss")
ax.plot(valLoss, label="Validation loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Loss over epochs")
ax.legend()

# Write the image before displaying it, so the saved file does not depend on
# what the active backend does to the figure once it has been shown.
fig.savefig(plot_path)
plt.show()

Plot not generated - model loaded from file


Saving/Loading Model

In [8]:
# Persist the whole model object with pickle at `model_path`.
# NOTE(review): a pickled module can only be loaded where the ArtsClassifier
# class is importable/defined, and pickle files must come from a trusted source.
with open(model_path, mode='wb') as checkpoint_file:
    pickle.dump(model, checkpoint_file)
    print("Pickle file saved")

END