# Imports

In [None]:
!pip install torchview

In [None]:
import os
import io
import numpy as np
import pandas as pd 
from pathlib import Path
import pickle

import torch
from torch import nn
import torch.nn.functional as F

from tqdm.auto import tqdm

import torchinfo
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torchvision
from torchvision import datasets, models
from torchvision.transforms import ToTensor, Compose
from torchvision.transforms.functional import to_pil_image, pil_to_tensor

from sklearn.preprocessing import StandardScaler


import PIL

import matplotlib.pyplot as plt
from matplotlib import colormaps

In [None]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Root of the Real/AI (SD + LD) art dataset on Kaggle; the dataset folder name is nested twice.
art_dataset_path = Path("/kaggle/input/real-ai-art") / "Real_AI_SD_LD_Dataset" / "Real_AI_SD_LD_Dataset"

In [None]:
# Image folders for each split (one sub-directory per class).
train_art_dir = art_dataset_path / "train"
valid_art_dir = art_dataset_path / "test" # Change later when validation set is available
test_art_dir = art_dataset_path / "test"
# NOTE(review): validation currently reuses the test images and CSV, so any LR or
# model-selection decision driven by validation loss is also tuned on the test split.

# Pre-computed colour-feature CSV files (despite the `_dir` suffix these are file paths).
train_art_csv_dir = "/kaggle/input/real-ai-art/train_art_col_feat.csv"
valid_art_csv_dir = "/kaggle/input/real-ai-art/test_art_col_feat.csv" # Change later when validation set is available
test_art_csv_dir = "/kaggle/input/real-ai-art/test_art_col_feat.csv"

In [None]:
# The 30 style class names in sorted (ASCII) order: 10 AI_LD_* styles, 10 AI_SD_*
# styles, then the 10 human-made styles. The commented-out feature-generation cell
# maps each image's parent folder name to its index in this list.
# NOTE(review): the human-made class is spelled "ukiyo_e" while the AI classes use
# "ukiyo-e"; presumably this mirrors the folder names on disk — confirm.
art_class_labels = [
    'AI_LD_art_nouveau',
    'AI_LD_baroque',
    'AI_LD_expressionism',
    'AI_LD_impressionism',
    'AI_LD_post_impressionism',
    'AI_LD_realism',
    'AI_LD_renaissance',
    'AI_LD_romanticism',
    'AI_LD_surrealism',
    'AI_LD_ukiyo-e',
    'AI_SD_art_nouveau',
    'AI_SD_baroque',
    'AI_SD_expressionism',
    'AI_SD_impressionism',
    'AI_SD_post_impressionism',
    'AI_SD_realism',
    'AI_SD_renaissance',
    'AI_SD_romanticism',
    'AI_SD_surrealism',
    'AI_SD_ukiyo-e',
    'art_nouveau',
    'baroque',
    'expressionism',
    'impressionism',
    'post_impressionism',
    'realism',
    'renaissance',
    'romanticism',
    'surrealism',
    'ukiyo_e']

# Feature Creation

In [None]:
def get_img_colour_features(image):
    """Return the 9 representative-colour features of a PIL image.

    The image is shrunk to a single pixel in RGB, HSV and CMYK space. The
    channel values are concatenated as ``R G B H S V C M Y``; the CMYK ``K``
    channel is dropped.

    Images not already in ``RGB`` mode (greyscale, palette, RGBA, CMYK
    JPEGs, ...) are converted to ``RGB`` first. Otherwise the "RGB" slice
    would not have exactly three values, and the concatenation would either
    fail or produce a feature vector of the wrong length.

    Args:
        image: A ``PIL.Image.Image``.

    Returns:
        A 1-D ``numpy`` array with 9 values.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Resizing to 1x1 collapses the whole image into one filtered pixel.
    rgb_img_features = np.array(image.resize((1, 1))).squeeze()
    hsv_img_features = np.array(image.convert("HSV").resize((1, 1))).squeeze()
    # [:-1] drops the trailing K channel of CMYK.
    cmyk_img_features = np.array(image.convert("CMYK").resize((1, 1))).squeeze()[:-1]

    return np.concatenate((rgb_img_features, hsv_img_features, cmyk_img_features))  # R G B H S V C M Y

In [None]:
# Sample image for a quick sanity check of the colour features.
# NOTE: these paths point to the ArtBench-10 dataset, not to `art_dataset_path`.
#test_img = PIL.Image.open("/kaggle/input/artbench-10-256px/artbench-10-imagefolder-split/artbench-10-imagefolder-split/test/art_nouveau/achille-beltrame_fly-of-gabriele-dannunzio-over-trieste-1915.jpg")
#test_img = PIL.Image.open("/kaggle/input/artbench-10-256px/artbench-10-imagefolder-split/artbench-10-imagefolder-split/test/romanticism/albert-bierstadt_sentinel-falls-and-cathedral-peaks-in-the-yosemite-valley-1864.jpg")
#test_img = PIL.Image.open("/kaggle/input/artbench-10-256px/artbench-10-imagefolder-split/artbench-10-imagefolder-split/test/realism/abbott-handerson-thayer_azores.jpg")
test_img = PIL.Image.open("/kaggle/input/artbench-10-256px/artbench-10-imagefolder-split/artbench-10-imagefolder-split/test/ukiyo_e/chokosai-eisho_99.jpg")

In [None]:
# Display the path the sample image was opened from
test_img.filename

In [None]:
# Colour features (R G B H S V C M Y) of the sample image
img_col_features = get_img_colour_features(test_img)
img_col_features

In [None]:
def generate_col_features(dir_img_dataset):
    """Build a DataFrame of colour features for every ``.jpg`` under a directory.

    Walks ``dir_img_dataset`` recursively and computes the 9 colour features
    of each ``.jpg`` file with ``get_img_colour_features``. After each visited
    directory it prints that directory and the running image count.

    Args:
        dir_img_dataset: Root directory of the image dataset.

    Returns:
        ``pandas.DataFrame`` with columns
        ``Path, Red, Green, Blue, Hue, Sat, Val, Cyan, Mag, Yel``. ``Path`` holds
        the file path and the remaining columns hold numeric feature values.
    """
    count = 0
    rows = []

    for sub_dir, _dirs, img_files in os.walk(dir_img_dataset):
        for img_filename in img_files:
            if not img_filename.endswith(".jpg"):
                continue

            img_path = str(os.path.join(sub_dir, img_filename))
            # The context manager closes the file handle once the features are computed.
            with PIL.Image.open(img_path) as img:
                img_col_features = get_img_colour_features(img)

            # Keep the path as a string and the features numeric: building the
            # row with np.append would coerce every feature value to a string.
            rows.append([img_path, *img_col_features.tolist()])
            count += 1

        print(sub_dir, count)

    return pd.DataFrame(
        data=rows,
        columns=["Path", "Red", "Green", "Blue", "Hue", "Sat", "Val", "Cyan", "Mag", "Yel"],
    )

In [None]:
# # Colour feature generation
# train_col_features = generate_col_features(train_art_dir)
# test_col_features = generate_col_features(test_art_dir)

# train_col_features["style"] = train_col_features.apply(lambda df_row: df_row["Path"].split("/")[-2], axis=1)
# test_col_features["style"] = test_col_features.apply(lambda df_row: df_row["Path"].split("/")[-2], axis=1)
# train_col_features["label"] = train_col_features.apply(lambda df_row: art_class_labels.index(df_row["style"]), axis=1)
# test_col_features["label"] = test_col_features.apply(lambda df_row: art_class_labels.index(df_row["style"]), axis=1)

# # Saving to CSV
# train_col_features.to_csv("train_art_col_feat.csv")
# test_col_features.to_csv("test_art_col_feat.csv")

# raise SystemExit("Stopping!")

In [None]:
# Load the pre-computed colour features, reusing the CSV paths defined with the
# split directories instead of repeating the literals.
train_col_features = pd.read_csv(train_art_csv_dir)
test_col_features = pd.read_csv(test_art_csv_dir)
# Drop the index column that DataFrame.to_csv writes by default (if present).
train_col_features.drop("Unnamed: 0", inplace=True, axis=1, errors='ignore')
test_col_features.drop("Unnamed: 0", inplace=True, axis=1, errors='ignore')

In [None]:
# Preview the raw training colour features
train_col_features.head()

In [None]:
scaler = StandardScaler()

# Columns 1..9 hold the colour features (Red ... Yel); column 0 is Path.
# Fit on the training split only.
train_col_features[train_col_features.columns[1:10]] = scaler.fit_transform(train_col_features[train_col_features.columns[1:10]])
# Reuse the training statistics for the test split: re-fitting here would leak
# test statistics and put the two splits on different scales.
test_col_features[test_col_features.columns[1:10]] = scaler.transform(test_col_features[test_col_features.columns[1:10]])

In [None]:
# Preview the test colour features (scaled by the previous cell, if it has run)
test_col_features.head()

In [None]:
# Code to check for duplicate image files that share the same file name
# count = 0
# for sub_dir, dirs, img_files in os.walk(test_art_dir):
#     for img_filename in img_files:
#         if img_filename.endswith(".jpg"):
#             if len(test_col_features[test_col_features['Path'].str.endswith("/"+img_filename.split("/")[-1])].iloc[:,1:-2].values) != 1:
#                 print(sub_dir, img_filename)
#             count+=1

#     print(sub_dir, count)

# Dataset Creation

In [None]:
# Make a CSV with additional features
# Build a class like the one below to combine the features with the image and the label
# https://medium.com/bivek-adhikari/creating-custom-datasets-and-dataloaders-with-pytorch-7e9d2f06b660
# https://pytorch.org/vision/stable/_modules/torchvision/datasets/folder.html#ImageFolder
# Inherit from the DatasetFolder class and override the __getitem__() method
from sklearn.preprocessing import StandardScaler
from torchvision.datasets import DatasetFolder
from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS

class ArtStyleDataset(DatasetFolder):
    """Image-folder dataset that also yields pre-computed colour features.

    Images are discovered by ``DatasetFolder`` under ``dir_root``. Each one is
    paired with its colour-feature row(s) from a CSV laid out as
    ``Path Red Green Blue Hue Sat Val Cyan Mag Yel style label``. Rows are
    matched to images by file name only.

    Items are ``(image, colour_features, target)``. ``colour_features`` is a
    float tensor with one row per matching CSV entry, i.e. shape ``(1, 9)``
    when file names are unique.
    """

    def __init__(self, dir_root, csv_file_path, transform=None, 
                 target_transform=None, is_valid_file=None, feature_scaler=None):
        """Index the images and load/scale their colour features.

        Args:
            dir_root: Root directory containing one sub-directory per class.
            csv_file_path: CSV file holding the colour features.
            transform: Optional transform applied to each loaded image.
            target_transform: Optional transform applied to each target.
            is_valid_file: Optional predicate selecting which files to load;
                when omitted, the standard image extensions are used.
            feature_scaler: Already-fitted scaler (e.g. the training split's)
                used to scale the colour features. When ``None``, a new
                ``StandardScaler`` is fitted on this CSV and stored in
                ``self.feature_scaler``.
        """
        super().__init__(
            dir_root,
            default_loader,
            IMG_EXTENSIONS if is_valid_file is None else None,
            transform=transform,
            # Forwarded so DatasetFolder stores target_transform (read in
            # __getitem__) and, when extensions is None, has a predicate to use.
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )

        self.dir_root = dir_root
        self.transform = transform
        self.feature_scaler = feature_scaler
        self.img_feat_data = self.get_preprocesseed_features(csv_file_path) # Path Red Green Blue Hue Sat Val Cyan Mag Yel style label

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(image, colour_features, target)`` for sample ``index``."""
        path, target = self.samples[index]
        sample = self.loader(path)
        img_feature = self.img_feat_data[index]

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, img_feature, target

    def get_preprocesseed_features(self, csv_file_path):
        """Load, scale and align the CSV colour features with ``self.samples``.

        CSV columns 1-9 (after dropping a pandas index column, if present) are
        standardised with ``self.feature_scaler``, which is fitted here if it
        is unset. For every sample, the CSV rows whose ``Path`` ends with
        ``"/" + <sample file name>`` are collected, keeping columns ``1:-2``.

        Returns:
            List with one float tensor per entry of ``self.samples``.
        """
        print("Preprocessing...")

        col_features = pd.read_csv(csv_file_path)
        col_features.drop("Unnamed: 0", inplace=True, axis=1, errors='ignore')

        # Scaling
        scaled_cols = col_features.columns[1:10]
        if self.feature_scaler is None:
            self.feature_scaler = StandardScaler().fit(col_features[scaled_cols])
        col_features[scaled_cols] = self.feature_scaler.transform(col_features[scaled_cols])

        # Index the CSV rows by file name once, instead of scanning the whole
        # Path column for every sample (quadratic in dataset size). A Path ends
        # with "/" + name exactly when it contains "/" and its last
        # "/"-separated component equals name, so the matches are unchanged.
        feature_values = col_features.iloc[:, 1:-2].to_numpy()
        rows_by_name = {}
        for row_idx, csv_path in enumerate(col_features['Path']):
            _, sep, name = str(csv_path).rpartition("/")
            if sep:
                rows_by_name.setdefault(name, []).append(row_idx)

        col_features_tensor_list = []
        for path, _ in tqdm(self.samples):
            matching_rows = np.asarray(rows_by_name.get(path.split("/")[-1], []), dtype=np.intp)
            col_features_tensor_list.append(torch.Tensor(feature_values[matching_rows]))

        print("Feature count: ", len(col_features_tensor_list))

        return col_features_tensor_list


In [None]:
# TODO: Must be done with the custom dataset to be created
# NOTE(review): presumably this means the preprocessing should eventually be derived
# from the art dataset itself rather than from the ImageNet weights — confirm.
# Preprocessing transforms bundled with the pretrained ConvNeXt-Base weights;
# reused for every split and for Grad-CAM inputs below.
imagenet_weights = models.ConvNeXt_Base_Weights.DEFAULT
preprocess_transforms = imagenet_weights.transforms()
preprocess_transforms

In [None]:
# Exploratory dataset over the validation images (currently the test split).
# No scaler is passed, so the dataset fits its own StandardScaler on this CSV.
exp_data = ArtStyleDataset( # EXPERIMENT!!!!!!!!!!!!!!!!!
    valid_art_dir,
    "/kaggle/input/real-ai-art/test_art_col_feat.csv",
    transform=preprocess_transforms
)

In [None]:
# with open('./test_col_feature_scaler.pkl','wb') as sc_f:
#     pickle.dump(exp_data.feature_scaler, sc_f)
# # with open('./test_col_feature_scaler.pkl','rb') as sc_f:
# #     sc = pickle.load(sc_f)

In [None]:
# Per-feature statistics learnt by the exploratory dataset's scaler
print("Means: ", exp_data.feature_scaler.mean_)
print("Var: ", exp_data.feature_scaler.var_)

In [None]:
def show_image(image, img_feature, label):
    """Print a sample's label and colour features, then plot its image.

    ``image`` is a channel-first tensor; it is permuted to channel-last for
    matplotlib.
    """
    for caption, value in (("Label", label), ("Feature", img_feature)):
        print(f"{caption}: {value}")
    channels_last = image.permute(1, 2, 0)
    plt.imshow(channels_last)
    plt.show()


show_image(*exp_data[0])

In [None]:
# Shuffled single-sample loader over the exploratory dataset
# (not referenced by the later cells of this notebook).
exp_data_dl = DataLoader(
    exp_data,
    batch_size=1,
    shuffle=True,
    num_workers=os.cpu_count()
)

In [None]:
# os.cpu_count() may return None, which DataLoader does not accept.
NUM_WORKERS = os.cpu_count() or 0


def get_art_image_dl(
    train_dir: str,
    train_csv_dir: str,
    valid_dir: str,
    valid_csv_dir: str,
    test_dir: str,
    test_csv_dir: str,
    pin_memory: bool,
    transform: Compose,
    batch_size: int,
    num_workers: int = NUM_WORKERS,
    scaler_path: str = './col_feature_scaler.pkl',
):
    """Create the train/validation/test DataLoaders of ``ArtStyleDataset``.

    The colour-feature scaler is fitted on the training CSV only. It is then
    reused for both the validation and the test split, so every split is
    standardised with the same statistics. The fitted scaler is also pickled
    to ``scaler_path``.

    Args:
        train_dir, valid_dir, test_dir: Image root folders (one sub-folder per class).
        train_csv_dir, valid_csv_dir, test_csv_dir: Colour-feature CSV files.
        pin_memory: Forwarded to every DataLoader.
        transform: Image transform applied by every dataset.
        batch_size: Batch size of the train and validation loaders (the test
            loader always uses 1).
        num_workers: DataLoader worker processes.
        scaler_path: Where the fitted scaler is pickled.

    Returns:
        ``(train_dl, valid_dl, test_dl, class_names, feature_scaler)``.
    """
    print("Train split creation...")
    # The training dataset fits the colour-feature scaler
    train_data = ArtStyleDataset(
        train_dir,
        train_csv_dir,
        transform=transform
    )
    train_data_dl = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    print("Validation split creation...")
    valid_data = ArtStyleDataset(
        valid_dir,
        valid_csv_dir,
        transform=transform,
        feature_scaler=train_data.feature_scaler
    )
    valid_data_dl = DataLoader(
        valid_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    print("Test split creation...")
    test_data = ArtStyleDataset(
        test_dir,
        test_csv_dir,
        transform=transform,
        feature_scaler=train_data.feature_scaler
    )
    test_data_dl = DataLoader(
        test_data,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    print("Colour feature Means     : ", train_data.feature_scaler.mean_)
    print("Colour feature Variances : ", train_data.feature_scaler.var_)

    print("Saving standard scaler...")
    with open(scaler_path, 'wb') as sc_f:
        pickle.dump(train_data.feature_scaler, sc_f)

    return train_data_dl, valid_data_dl, test_data_dl, train_data.classes, train_data.feature_scaler

In [None]:
# Batch size of the train and validation loaders (the test loader uses 1)
BATCH_SIZE=32

In [None]:
# Build the three loaders; also returns the class names and the colour-feature
# scaler fitted on the training CSV.
train_art_dl, valid_art_dl, test_art_dl, art_style_classes, col_feature_scaler = get_art_image_dl(
    train_dir=train_art_dir, 
    valid_dir=valid_art_dir, 
    test_dir=test_art_dir,
    train_csv_dir=train_art_csv_dir, 
    valid_csv_dir=valid_art_csv_dir, 
    test_csv_dir=test_art_csv_dir,
    pin_memory=True,
    transform=preprocess_transforms,
    batch_size=BATCH_SIZE
)

In [None]:
# len() of a DataLoader is its number of batches, so these "sizes" are batch counts.
print(f"Train Size      : {len(train_art_dl)} | Batch Size: {BATCH_SIZE}")
print(f"Validation Size : {len(valid_art_dl)} | Batch Size: {BATCH_SIZE}")
print(f"Test Size       : {len(test_art_dl)}  | Batch Size: {1}")

# Data Exploration

In [None]:
# Fetch one training batch and inspect the shapes of images, colour features and labels
train_art_img_batch, train_col_feat_batch, train_art_labels_batch = next(iter(train_art_dl))
train_art_img_batch.shape, train_col_feat_batch.shape, train_art_labels_batch.shape

In [None]:
# First sample of the batch. NOTE: despite its name, `img_filename` holds the image tensor, not a file name.
img_filename, art_style_label = train_art_img_batch[0], train_art_labels_batch[0]

In [None]:
# Image tensor dimensions; img_h / img_w are reused later for model summaries and resizing.
img_ch, img_h, img_w = img_filename.shape
img_ch, img_h, img_w

In [None]:
# Integer class index of the first sample
art_style_label

In [None]:
# Class names discovered by the training dataset, indexed by label
art_style_classes

In [None]:
# 2 x 4 grid of images sampled at random (with replacement) from the first training batch.
fg = plt.figure(figsize=(9, 4))
rows, cols = 2, 4
for i in range(1, rows * cols + 1):
    rand_art_idx = torch.randint(0, len(train_art_img_batch), size=[1]).item()
    art_img = train_art_img_batch[rand_art_idx]
    art_style_label = train_art_labels_batch[rand_art_idx]

    ax = fg.add_subplot(rows, cols, i)
    # Channel-first -> channel-last for matplotlib (cmap is ignored for 3-channel images).
    ax.imshow(art_img.squeeze().permute(1, 2, 0), cmap="gray")
    ax.set_title(art_style_classes[art_style_label])
    ax.axis(False)

# Model Development

Potential improvements
- Colour percentages
- Edges

In [None]:
from torchvision.models import convnext_base
from torchvision.models import ConvNeXt_Base_Weights
class ArtVisionModel(nn.Module):
    """ConvNeXt-Base classifier that fuses multi-scale CNN features with colour features.

    Pooled outputs of three backbone stages (2560 values in total, per the
    classifier's ``in_features``) are concatenated with a 64-d projection of
    the hand-crafted colour features and passed to an MLP head.
    """

    def __init__(self, n_classes, n_art_img_features=9):
        super().__init__()

        # Pretrained ConvNeXt-Base backbone
        self.convnext = convnext_base(weights=ConvNeXt_Base_Weights.DEFAULT)

        # Freeze every backbone stage except the last one (features[-1])
        for frozen_param in self.convnext.features[:-1].parameters():
            frozen_param.requires_grad = False

        backbone_stages = self.convnext.features

        # Aliases onto the backbone stages. These attribute names become
        # state_dict keys, so they must stay as they are.
        self.features = backbone_stages
        self.features_low_CNB = backbone_stages[:4]   # stages 0-3
        self.features_low = backbone_stages[4]
        self.features_mid_CNB = backbone_stages[5]
        self.features_mid = backbone_stages[6]
        self.features_high_CNB = backbone_stages[7]

        # Global average pooling + flatten for the intermediate feature maps
        self.mid_feat_pooling = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())

        # ConvNeXt's own pooling followed by the first layer of its classifier head
        self.high_feat_pooling = nn.Sequential(
            self.convnext.avgpool,
            self.convnext.classifier[0],
            nn.Flatten(),
        )

        # Linear projection of the colour features to 64 values
        self.col_features = nn.Sequential(
            nn.Linear(n_art_img_features, 64, bias=True),
            nn.Flatten(),
        )

        # MLP head over the fused CNN + colour features
        fused_dim = 2560 + 64
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.4),
            nn.Linear(fused_dim, 512, bias=True),
            nn.GELU(),
            nn.Dropout(p=0.1),
            nn.Linear(512, n_classes, bias=True),
        )

    def forward(self, img, img_col_feat):
        """Return class logits for an image batch and its colour features."""
        low_map = self.features_low(self.features_low_CNB(img))
        mid_map = self.features_mid(self.features_mid_CNB(low_map))
        high_map = self.features_high_CNB(mid_map)

        fused = torch.cat(
            (
                self.high_feat_pooling(high_map),
                self.mid_feat_pooling(mid_map),
                self.mid_feat_pooling(low_map),
                self.col_features(img_col_feat),
            ),
            dim=1,
        )
        return self.classifier(fused)
    

In [None]:
# art_vision_model = torch.load("/kaggle/input/art-model-pytorch/art_brain_model.pt", map_location=device)
# model_scripted = torch.jit.script(art_vision_model)# Export to TorchScript

# model_scripted.save('art_brain_model_scripted.pt') # Save
#torch.save(art_vision_model.state_dict(), "./art_vision_model_state.pt")

In [None]:
# One output logit per class discovered in the training folder
art_vision_model = ArtVisionModel(n_classes=len(art_style_classes)).to(device)

In [None]:
# Last block of the last ConvNeXt stage; its block[0] layer is hooked for Grad-CAM later
art_vision_model.features[-1][-1]

In [None]:
# Layer-by-layer summary for a batch of 32 images (3 x img_h x img_w) plus 32
# nine-value colour-feature vectors
torchinfo.summary(
    model=art_vision_model, 
    input_size=[(32, 3, img_h, img_w), (32, 9)],
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"]
) 

In [None]:
import torchvision
from torchview import draw_graph

# Render the computation graph (nested modules expanded) for a batch of 32
# images and 32 colour-feature vectors
model_graph = draw_graph(art_vision_model, input_size=[(32, 3, img_h, img_w), (32, 9)], expand_nested=True)
model_graph.visual_graph

In [None]:
#art_vision_model = nn.DataParallel(art_vision_model)

In [None]:
# Cross-entropy over the style classes. Adam receives every parameter, but the
# frozen backbone stages (requires_grad=False) never receive gradients, so only
# the trainable layers are updated.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(art_vision_model.parameters(), lr=0.001)
# Multiply the LR by 0.1 once the validation loss has not improved for more than
# `patience` (2) consecutive epochs, never going below 1e-6.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm the installed version.
lr_scheduler = ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=2, 
    threshold=0.0001, cooldown=0, min_lr=1e-06, verbose=True
)

In [None]:
"""
Contains functions for training and testing a PyTorch model.
"""

from typing import Dict, List, Tuple

def train_step_procedure(
    epoch: int,
    model: torch.nn.Module,
    train_dl: torch.utils.data.DataLoader,
    loss_fn: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device) -> Tuple[float, float]:
    """Train ``model`` for one epoch over ``train_dl``.

    Loss and accuracy are averaged per sample, with each batch weighted by its
    size. Averaging per-batch means instead over-weights a smaller final
    batch. This assumes ``loss_fn`` returns the batch mean, as
    ``nn.CrossEntropyLoss`` does by default.

    Args:
        epoch: Zero-based epoch number; used only for the progress-bar label.
        model: Model called as ``model(images, colour_features)``.
        train_dl: Loader yielding ``(images, colour_features, labels)``.
        loss_fn: Loss taking ``(logits, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen in this epoch.

    Raises:
        ValueError: If ``train_dl`` yields no samples.
    """
    model.train()

    total_loss, total_correct, n_seen = 0.0, 0, 0

    training_pbar = tqdm(train_dl, unit="batch", total=len(train_dl))
    training_pbar.set_description("Epoch %02d" % (epoch+1))

    for train_image, train_feat, img_label in training_pbar:
        train_image, train_feat, img_label = train_image.to(device), train_feat.to(device), img_label.to(device)
        batch_size = img_label.size(0)

        # Forward propagation and loss calculation
        y_pred = model(train_image, train_feat)
        loss = loss_fn(y_pred, img_label)

        # Reset gradients, back-propagate and update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running per-sample metrics
        total_loss += loss.item() * batch_size
        total_correct += (torch.argmax(y_pred, dim=1) == img_label).sum().item()
        n_seen += batch_size
        training_pbar.set_postfix({'Loss': "%.4f" % (total_loss / n_seen)})

    if n_seen == 0:
        raise ValueError("train_dl yielded no samples")

    return total_loss / n_seen, total_correct / n_seen

def validation_step_procedure(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn: torch.nn.Module,
              device: torch.device) -> Tuple[float, float]:
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Loss and accuracy are averaged per sample, with each batch weighted by its
    size, so a partial last batch is not over-weighted. This assumes
    ``loss_fn`` returns the batch mean, as ``nn.CrossEntropyLoss`` does by
    default.

    Args:
        model: Model called as ``model(images, colour_features)``.
        dataloader: Loader yielding ``(images, colour_features, labels)``.
        loss_fn: Loss taking ``(logits, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    # Evaluation mode activation
    model.eval()

    total_loss, total_correct, n_seen = 0.0, 0, 0

    # Inference mode disables autograd tracking for faster evaluation
    with torch.inference_mode():
        for valid_image, valid_feat, valid_label in dataloader:
            valid_image, valid_feat, valid_label = valid_image.to(device), valid_feat.to(device), valid_label.to(device)
            batch_size = valid_label.size(0)

            valid_pred = model(valid_image, valid_feat)

            total_loss += loss_fn(valid_pred, valid_label).item() * batch_size
            total_correct += (valid_pred.argmax(dim=1) == valid_label).sum().item()
            n_seen += batch_size

    if n_seen == 0:
        raise ValueError("dataloader yielded no samples")

    return total_loss / n_seen, total_correct / n_seen

def train(
    model: torch.nn.Module,
    train_dl: torch.utils.data.DataLoader,
    validation_dl: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    loss_fn: torch.nn.Module,
    epochs: int,
    lr_scheduler: ReduceLROnPlateau,
    device: torch.device
) -> Dict[str, List]:
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Args:
        model: Model to train; moved to ``device`` first.
        train_dl: Training loader yielding ``(images, colour_features, labels)``.
        validation_dl: Validation loader with the same item layout.
        optimizer: Optimizer used by ``train_step_procedure``.
        loss_fn: Loss taking ``(logits, labels)``.
        epochs: Number of epochs to run.
        lr_scheduler: Metric-driven scheduler stepped once per epoch with the
            validation loss (``lr_scheduler.step(valid_loss)``), e.g.
            ``ReduceLROnPlateau``.
        device: Device used for training and validation.

    Returns:
        Dict of per-epoch lists under ``"train_loss"``, ``"train_acc"``,
        ``"valid_loss"``, ``"valid_acc"`` and ``"lr"`` (the LR after the
        scheduler step).
    """
    results = {"train_loss": [],
               "train_acc": [],
               "valid_loss": [],
               "valid_acc": [],
               "lr": []
    }

    model.to(device)

    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step_procedure(
            epoch=epoch,
            model=model,
            train_dl=train_dl,
            loss_fn=loss_fn,
            optimizer=optimizer,
            device=device
        )

        valid_loss, valid_acc = validation_step_procedure(
            model=model,
            dataloader=validation_dl,
            loss_fn=loss_fn,
            device=device
        )

        # Learning rate adjustments driven by the validation loss
        lr_scheduler.step(valid_loss)

        print(
          f"train_loss: {train_loss:.4f} | "
          f"train_acc: {train_acc:.4f} | "
          f"valid_loss: {valid_loss:.4f} | "
          f"valid_acc: {valid_acc:.4f} | "
          f"lr: {optimizer.param_groups[0]['lr']}"
        )

        # Updating results
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["valid_loss"].append(valid_loss)
        results["valid_acc"].append(valid_acc)
        results["lr"].append(optimizer.param_groups[0]['lr'])

    return results

In [None]:
#torch.randn(3, 5, requires_grad=True), torch.empty(3, dtype=torch.long).random_(5)

In [None]:
# Number of training epochs
epochs = 15

In [None]:
# Train and keep the per-epoch metrics.
# NOTE: "train_restults" is misspelled; kept as-is in case other cells refer to it.
train_restults = train(
    model=art_vision_model, 
    train_dl=train_art_dl, 
    validation_dl=valid_art_dl, 
    optimizer=optimizer,
    loss_fn=loss_fn,
    lr_scheduler=lr_scheduler,
    epochs=epochs,
    device=device
)

In [None]:
# Plot training results

In [None]:
## art_vision_model = torch.load("/kaggle/input/art-model-pytorch/art_brain_model.pt", map_location=device)
## model_scripted = torch.jit.script(art_vision_model)# Export to TorchScript

## model_scripted.save('art_brain_model_scripted.pt') # Save

In [None]:
def get_attribution_pred(pred):
    """Collapse per-class scores into generator-level scores.

    Scores are summed over index ranges 0-9, 10-19 and 20 onwards. Assuming
    the class order of ``art_class_labels``, these are the AI_LD_* classes,
    the AI_SD_* classes and the human-made styles. Returns
    ``[ld_score, sd_score, real_score]``.
    """
    group_slices = (slice(0, 10), slice(10, 20), slice(20, None))
    return [np.sum(pred[group]) for group in group_slices]

In [None]:
def get_attribution_label(sub_label):
    """Map a fine-grained class index to its generator group index.

    Classes come in blocks of ten, so the group is the index floor-divided
    by 10: 0 for indices 0-9 (AI_LD_*), 1 for 10-19 (AI_SD_*) and 2 for
    20-29 (human-made), matching ``get_attribution_pred``.
    """
    # Floor division needs no `math` import (which this notebook lacks) and is exact for ints.
    return int(sub_label // 10)

In [None]:
# Fine-grained (per-style) predictions and labels
preds = []
labels = []

# Generator-level (attribution) predictions and labels: 0 = AI_LD, 1 = AI_SD, 2 = human
attr_preds = []
attr_labels = []

art_vision_model.eval()
with torch.inference_mode():
    # Loop over the test DataLoader created by get_art_image_dl (batch_size=1
    # there, but every sample of a larger batch is handled too).
    for test_image, test_feat, test_label in tqdm(
        test_art_dl, unit="images", total=len(test_art_dl)
    ):
        # Send data to target device
        test_image, test_feat, test_label = test_image.to(device), test_feat.to(device), test_label.to(device)

        # Batch inference
        test_pred = art_vision_model(test_image, test_feat)

        preds.extend(test_pred.argmax(dim=1).tolist())
        labels.extend(test_label.tolist())

        # Class probabilities, summed per generator group for attribution
        test_probs = F.softmax(test_pred, dim=1).cpu().numpy()
        for prob_row, sample_label in zip(test_probs, test_label.tolist()):
            attr_preds.append(np.argmax(get_attribution_pred(prob_row)))
            attr_labels.append(get_attribution_label(sample_label))

In [None]:
# Sanity check: one prediction per ground-truth label
len(preds), len(labels)

In [None]:
# First few fine-grained predictions vs labels
preds[:10], labels[:10]

In [None]:
# First few attribution (generator group) predictions vs labels
attr_preds[:10], attr_labels[:10]

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Passing the full label range keeps the matrix len(art_style_classes) square,
# aligned with display_labels, even when a class never appears in labels or preds.
cm = confusion_matrix(labels, preds, labels=list(range(len(art_style_classes))))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=art_style_classes)

fig, ax = plt.subplots(figsize=(15,15))
disp.plot(ax=ax)
plt.xticks(rotation=90)

In [None]:
# Display names in group-index order, matching get_attribution_pred /
# get_attribution_label: 0 = AI_LD_* classes (latent diffusion),
# 1 = AI_SD_* classes, 2 = human-made art.
gen_models = ["latent_diffusion", "standard_diffusion", "human"]

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# labels=[0, 1, 2] keeps the matrix 3x3 and aligned with gen_models even when a
# group is absent from the predictions or labels.
cm = confusion_matrix(attr_labels, attr_preds, labels=[0, 1, 2])
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=gen_models)

disp.plot()
plt.xticks(rotation=90)

In [None]:
# Persist only the learnt weights (state_dict); reloading requires constructing ArtVisionModel first.
ART_MODEL_PATH = "./art_classifier.pt"
torch.save(art_vision_model.state_dict(), ART_MODEL_PATH)
#torch.save(art_vision_model.state_dict(), "./art_vision_model_state.pt")

# GradCAM

In [None]:
def get_img_colour_features(image, col_feature_scaler):
    """Return the scaled colour features of a PIL image as a ``(1, 9)`` tensor.

    Computes ``R G B H S V C M Y`` from a 1x1 resize of the image in RGB, HSV
    and CMYK space (the CMYK ``K`` channel is dropped). The single feature row
    is then scaled with ``col_feature_scaler.transform``. Non-RGB images are
    converted to RGB first so exactly 9 features are produced.

    Args:
        image: A ``PIL.Image.Image``.
        col_feature_scaler: Fitted scaler exposing ``transform`` (e.g. the
            training ``StandardScaler``).

    Returns:
        Float tensor of shape ``(1, 9)``.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    rgb_img_features = np.array(image.resize((1, 1))).squeeze()
    hsv_img_features = np.array(image.convert('HSV').resize((1, 1))).squeeze()
    cmyk_img_features = np.array(image.convert('CMYK').resize((1, 1))).squeeze()[:-1]
    raw_features = np.concatenate((rgb_img_features, hsv_img_features, cmyk_img_features))

    return torch.Tensor(col_feature_scaler.transform([raw_features]))

In [None]:
#art_vision_model = torch.load("/kaggle/input/art-model-pytorch/art_classifier.pt", map_location=device).module

In [None]:
def get_model_pred(model, art_img_tensor, art_img_col_features):
    """Predict one image and capture Grad-CAM inputs from the last ConvNeXt block.

    Hooks ``model.features[-1][-1].block[0]`` to record its forward output
    (activations) and the gradient w.r.t. that output. It then
    back-propagates the score of the predicted class.

    Args:
        model: Model exposing ``features`` (ConvNeXt stages) and called as
            ``model(images, colour_features)``, e.g. ``ArtVisionModel``.
        art_img_tensor: One preprocessed image without a batch dimension
            (one is added here).
        art_img_col_features: Colour features for that image.

    Returns:
        ``(pred_index, gradients, activations)``: the argmax class index
        tensor, the ``grad_output`` tuple of the hooked layer and its
        forward output.

    Side effects:
        Puts ``model`` in eval mode and accumulates ``.grad`` on parameters
        that require grad. The ``requires_grad`` flags of ``model.features``
        are set to ``True`` during the call and then restored to their
        previous values.
    """
    feature_params = list(model.features.parameters())
    original_flags = [param.requires_grad for param in feature_params]
    for param in feature_params:
        param.requires_grad = True

    gradients = None
    activations = None

    def save_gradients(module, grad_input, grad_output):
        nonlocal gradients
        gradients = grad_output

    def save_activations(module, args, output):
        nonlocal activations
        activations = output

    # art_vision_model.features[-1][-1]
    #
    #     CNBlock(
    #       (block): Sequential(
    #         (0): Conv2d(1024, 1024, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=1024) | Size [1, 1024, 7, 7]
    #         (1): Permute()
    #         (2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
    #         (3): Linear(in_features=1024, out_features=4096, bias=True)
    #         (4): GELU(approximate='none')
    #         (5): Linear(in_features=4096, out_features=1024, bias=True)
    #         (6): Permute()
    #       )
    #       (stochastic_depth): StochasticDepth(p=0.5, mode=row)
    #     )
    target_layer = model.features[-1][-1].block[0]
    backward_handle = target_layer.register_full_backward_hook(save_gradients)
    forward_handle = target_layer.register_forward_hook(save_activations)

    try:
        model.eval()
        preds = model(art_img_tensor.unsqueeze(0), art_img_col_features)
        pred_index = preds.argmax(dim=1)
        preds[:, pred_index].backward()
    finally:
        # Always detach the hooks and restore the original trainable/frozen state
        backward_handle.remove()
        forward_handle.remove()
        for param, flag in zip(feature_params, original_flags):
            param.requires_grad = flag

    return pred_index, gradients, activations


In [None]:
def generate_grad_map(gradients, activations):
    """Build a Grad-CAM style heatmap from captured gradients and activations.

    Args:
        gradients: ``grad_output`` tuple from ``get_model_pred``;
            ``gradients[0]`` is assumed to be ``(1, C, H, W)``
            (``[1, 1024, 7, 7]`` for the hooked ConvNeXt-Base layer).
        activations: Forward output of the same layer, ``(1, C, H, W)``.
            It is not modified.

    Returns:
        Detached ``(H, W)`` tensor min-max scaled to ``[0, 1]``. It is all
        zeros when the map is constant, instead of NaN from 0/0.
    """
    # Per-channel weights: gradients averaged over batch and spatial dims
    avg_pooled_gradients = torch.mean(
        gradients[0], # Size [1, 1024, 7, 7]
        dim=[0, 2, 3]
    )

    # Weight each activation channel by its gradient (broadcast over batch and
    # spatial dims) without modifying the caller's tensor in place.
    weighted_activations = activations.detach() * avg_pooled_gradients.detach().view(1, -1, 1, 1)

    # Average the weighted channels
    heatmap = torch.mean(weighted_activations, dim=1).squeeze()

    # Experimental L2 normalisation: F.normalize defaults to dim=1, i.e. each
    # row of the 2-D map is normalised separately.
    heatmap = F.normalize(heatmap)

    # Sigmoid squashing (used here instead of the usual Grad-CAM ReLU)
    heatmap = torch.sigmoid(heatmap)

    # Min-max normalisation, guarded against a constant map
    heat_min = torch.min(heatmap)
    value_range = torch.max(heatmap) - heat_min
    if value_range > 0:
        heatmap = (heatmap - heat_min) / value_range
    else:
        heatmap = torch.zeros_like(heatmap)

    return heatmap.detach()

In [None]:
# First validation batch (the loader is not shuffled); its colour features are discarded here
test_art_img, _, test_art_labels = next(iter(valid_art_dl))

In [None]:
# Show sample 6 (channel-last for matplotlib) and its label.
# NOTE(review): the tensor has been through preprocess_transforms, so if those
# normalise the image the displayed colours will look off — confirm.
plt.matshow(test_art_img[6].permute(1,2,0))
test_art_labels[6]

In [None]:
# Use the pre-computed, already-scaled colour features of the same sample from
# the dataset. test_art_img went through preprocess_transforms (the ConvNeXt
# weights' transforms, which presumably normalise the tensor), so converting it
# back with to_pil_image would not recover the original colours.
# Item 6 of the first batch is dataset item 6 because valid_art_dl is built
# with shuffle=False in get_art_image_dl.
img_col_features = valid_art_dl.dataset.img_feat_data[6]
pred_index, gradients, activations = get_model_pred(art_vision_model, test_art_img[6].to(device), img_col_features.to(device))
pred_index

In [None]:
# Build and display the raw Grad-CAM heatmap at the hooked layer's resolution
heatmap = generate_grad_map(gradients, activations)
print(heatmap)
plt.matshow(heatmap.cpu())

In [None]:
def predict_image(art_img, model, col_feature_scaler, hm_opacity=0.3):
    """Classify a PIL image and blend its Grad-CAM heatmap over it.

    Relies on the notebook globals ``img_h``, ``preprocess_transforms`` and ``device``.

    Args:
        art_img: Input ``PIL.Image``; converted to RGB and resized to an
            ``img_h`` x ``img_h`` square.
        model: Model passed to ``get_model_pred`` (e.g. ``ArtVisionModel``).
        col_feature_scaler: Fitted scaler for the colour features.
        hm_opacity: Heatmap weight in the blend (0 = image only, 1 = heatmap only).

    Returns:
        ``(pred_index, overlay_img)``: the predicted class index tensor and the
        blended ``PIL.Image``.
    """
    # RGBA / greyscale / palette inputs would otherwise break the colour
    # features and Image.blend, which requires both images to share a mode.
    art_img = art_img.convert('RGB').resize((img_h, img_h), resample=PIL.Image.BICUBIC)
    # NOTE(review): training colour features were computed from the full-size
    # files (generate_col_features); here they come from the resized image.
    img_col_features = get_img_colour_features(art_img, col_feature_scaler)
    art_img_tensor = preprocess_transforms(art_img)

    pred_index, gradients, activations = get_model_pred(model, art_img_tensor.to(device), img_col_features.to(device))
    heatmap = generate_grad_map(gradients, activations)

    hm_overlay = to_pil_image(heatmap.detach().cpu(), mode='F').resize((img_h, img_h), resample=PIL.Image.BICUBIC)

    # YlOrRd colour map; squaring the heatmap first suppresses low activations
    col_map = colormaps['YlOrRd']
    hm_overlay = PIL.Image.fromarray(
        (255 * col_map(np.asarray(hm_overlay) ** 2)[:, :, :3]).astype(np.uint8)
    )

    overlay_img = PIL.Image.blend(art_img, hm_overlay, alpha=hm_opacity)

    return pred_index, overlay_img

In [None]:
# Grad-CAM overlay for a sample ArtBench-10 image (the alternatives are commented out).
#art_image = PIL.Image.open("/kaggle/input/artbench-10-256px/artbench-10-imagefolder-split/artbench-10-imagefolder-split/test/art_nouveau/achille-beltrame_fly-of-gabriele-dannunzio-over-trieste-1915.jpg")
#art_image = PIL.Image.open("/kaggle/input/artbench-10-256px/artbench-10-imagefolder-split/artbench-10-imagefolder-split/test/romanticism/albert-bierstadt_sentinel-falls-and-cathedral-peaks-in-the-yosemite-valley-1864.jpg")
#art_image = PIL.Image.open("/kaggle/input/artbench-10-256px/artbench-10-imagefolder-split/artbench-10-imagefolder-split/test/realism/abbott-handerson-thayer_azores.jpg")
art_image = PIL.Image.open("/kaggle/input/artbench-10-256px/artbench-10-imagefolder-split/artbench-10-imagefolder-split/test/ukiyo_e/chokosai-eisho_99.jpg")
pred_index, pred_hm = predict_image(art_image, art_vision_model, col_feature_scaler, hm_opacity=0.4)
# pred_index is a one-element tensor; it is used directly as the list index
print(art_style_classes[pred_index])
plt.imshow(pred_hm)

In [None]:
# TODO: Modularise the code. Check Notion
# - Breakdown to classes
# - Saving best model function implementation
# - example: https://debuggercafe.com/saving-and-loading-the-best-model-in-pytorch/