In [None]:
!git clone https://github.com/MicheleCazzola/artist-identification.git mlvm-project

!cd mlvm-project; git status

Cloning into 'mlvm-project'...
remote: Enumerating objects: 2187, done.[K
remote: Counting objects: 100% (139/139), done.[K
remote: Compressing objects: 100% (90/90), done.[K
remote: Total 2187 (delta 65), reused 113 (delta 48), pack-reused 2048 (from 1)[K
Receiving objects: 100% (2187/2187), 28.98 MiB | 36.64 MiB/s, done.
Resolving deltas: 100% (1184/1184), done.
On branch main
Your branch is up to date with 'origin/main'.

nothing to commit, working tree clean


In [2]:
# Copy the repository's `src` package next to the notebook so that the
# `from src.… import …` statements in the next cell resolve.
!cp -r mlvm-project/src .

In [3]:
import os
import torch
import csv
from src.model.network import MultiBranchArtistNetwork
from src.utils.utils import BackboneType
from src.config.config import Config
from torchvision import transforms
from PIL import Image

In [4]:
def preprocess_image(image_path, size, stats):
    """Load an image from disk and turn it into a normalized single-image batch.

    Args:
        image_path: Path of the image file to read.
        size: Side length in pixels; the image is resized to ``size x size``.
        stats: Mapping with "mean" and "std" entries (one value per RGB channel),
            forwarded to ``transforms.Normalize``.

    Returns:
        A float tensor of shape (1, 3, size, size) holding the normalized image.
    """
    transform = transforms.Compose([
        # Resizing straight to a square already yields size x size; a CenterCrop
        # of the same size afterwards would be an identity operation.
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=stats["mean"], std=stats["std"]),
    ])
    # The context manager guarantees the file handle is released, including for
    # multi-frame formats where PIL keeps the file open after loading.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    return transform(image).unsqueeze(0)

In [5]:
def predict_image(model, image_tensor, class_names, k=5):
    """Return the ``k`` most likely classes for the first image of a batch.

    Args:
        model: Callable applied to ``image_tensor``; ``model(image_tensor)[0]`` is
            assumed to be the 1-D score (logit) vector of the first image
            (TODO confirm against the network's forward).
        image_tensor: Input batch; only the first sample's prediction is used.
        class_names: Sequence mapping class index -> class name.
        k: Number of predictions to return (default 5). It is clamped to the
            number of classes so models with fewer than ``k`` classes do not crash.

    Returns:
        A list of ``(class_name, probability)`` pairs, most probable first.
    """
    with torch.no_grad():
        logits = model(image_tensor)[0]
        probabilities = torch.nn.functional.softmax(logits, dim=0)
        top_prob, top_idx = torch.topk(probabilities, min(k, probabilities.numel()))
        # .tolist() yields plain Python ints/floats instead of 0-d tensors.
        return [(class_names[idx], prob) for idx, prob in zip(top_idx.tolist(), top_prob.tolist())]

In [6]:
def evaluate_model(model, test_dir, class_names, input_size, norm_stats, device,
                   csv_path='/kaggle/working/predictions.csv'):
    """Predict classes for every image in ``test_dir`` and write them to a CSV file.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        test_dir: Directory holding the test images (not searched recursively);
            only .png/.jpg/.jpeg files (case-insensitive) are used.
        class_names: Sequence mapping class index -> class name.
        input_size: Side length forwarded to ``preprocess_image``.
        norm_stats: Normalization statistics forwarded to ``preprocess_image``.
        device: Device the input tensors are moved to.
        csv_path: Output CSV path (defaults to the Kaggle working directory).
            Each row holds the image file name followed by the class names
            returned by ``predict_image``.
    """
    model.eval()
    # Sort so the CSV row order is reproducible; os.listdir order is arbitrary.
    image_files = sorted(
        f for f in os.listdir(test_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )

    with open(csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['Image Name', 'Class1', 'Class2', 'Class3', 'Class4', 'Class5'])

        total = len(image_files)
        for step, image_file in enumerate(image_files, start=1):
            image_path = os.path.join(test_dir, image_file)
            image_tensor = preprocess_image(image_path, input_size, norm_stats).to(device)
            predictions = predict_image(model, image_tensor, class_names)

            writer.writerow([image_file] + [class_name for class_name, _ in predictions])

            # Report progress every 100 images and always after the last one.
            if step % 100 == 0 or step == total:
                print(f"Done step {step}/{total}")

    print(f"Predictions saved to {csv_path}")

In [7]:
# Dataset locations on Kaggle. The train split is only read to cross-check the
# class names; predictions are produced for the test split.
train_dir = "/kaggle/input/artist-identification/artist_dataset/train"
test_dir = "/kaggle/input/artist-identification/artist_dataset/test"

# Model / preprocessing settings.
# NOTE(review): presumably these must match the values used at training time — confirm.
NUM_CLASSES = 161
INPUT_SIZE = 512

# Per-channel (R, G, B) statistics fed to transforms.Normalize in preprocess_image.
# NOTE(review): these differ from the ImageNet values; presumably computed on the
# training split — confirm they match the statistics used during training.
NORM_STATS = dict(
    mean=[0.47527796030044556, 0.42012834548950195, 0.3588443994522095],
    std=[0.2794029116630554, 0.27445685863494873, 0.264132022857666],
)

# Prefer the GPU when one is available.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [10]:
MODEL_PATH = "/kaggle/input/last_regnet_bs16_hog/pytorch/default/1/best_model_29.pth.tar"
hog_config = Config().hog
model = MultiBranchArtistNetwork(
    num_classes=NUM_CLASSES,
    stn=BackboneType.REGNET_X_400MF,
    use_handcrafted=True,
    hog_params=hog_config,
).to(DEVICE)

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only runtime.
if MODEL_PATH.endswith(".pth"):
    # Plain state_dict file.
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
elif MODEL_PATH.endswith(".pth.tar"):
    # Checkpoint dict whose weights are stored under "model_state_dict".
    # SECURITY: weights_only=False unpickles arbitrary objects — only load trusted checkpoints.
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
    if "model_state_dict" not in checkpoint:
        raise KeyError(
            f"'model_state_dict' not found in checkpoint {MODEL_PATH}; keys: {list(checkpoint)}"
        )
    state_dict = checkpoint["model_state_dict"]
else:
    # Fail loudly: otherwise the model would be evaluated without the checkpoint weights.
    raise ValueError(f"Unsupported checkpoint extension: {MODEL_PATH}")
model.load_state_dict(state_dict)

Downloading: "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth" to /root/.cache/torch/hub/checkpoints/regnet_x_400mf-adf1edd5.pth
100%|██████████| 21.3M/21.3M [00:00<00:00, 110MB/s] 


In [11]:
# Hard-coded reference list of the artist class labels. A later cell asserts that
# the sorted sub-directory names of `train_dir` equal this list exactly, and that
# list is what maps model output indices back to names during prediction. The
# order therefore matters and must stay sorted.
# NOTE(review): presumably mirrors the label order used at training time — confirm.
CLASS_NAMES = [
    'adriaen-van-ostade', 'agnolo-bronzino', 'albrecht-altdorfer', 'albrecht-durer', 'aleksey-savrasov',
    'alexey-venetsianov', 'alfred-stevens', 'anders-zorn', 'andrea-del-sarto', 'andrea-mantegna', 
    'andrei-ryabushkin', 'anthony-van-dyck', 'antoine-pesne', 'antoine-watteau', 'arkhip-kuindzhi', 
    'bartolome-esteban-murillo', 'benjamin-west', 'benozzo-gozzoli', 'bernardo-bellotto', 
    'boris-kustodiev', 'camille-corot', 'camille-pissarro', 'canaletto', 'caravaggio', 'carlo-crivelli',
    'caspar-david-friedrich', 'charles-francois-daubigny', 'cornelis-springer', 'correggio', 
    'dante-gabriel-rossetti', 'david-teniers-the-younger', 'diego-velazquez', 'dmitry-levitzky', 
    'domenico-ghirlandaio', 'edouard-manet', 'edward-burne-jones', 'edward-hopper', 'edwin-henry-landseer',
    'el-greco', 'esaias-van-de-velde', 'eugene-boudin', 'eugene-delacroix', 'filippo-lippi', 'fra-angelico',
    'francesco-guardi', 'francisco-de-zurbaran', 'francisco-goya', 'frans-hals', 'frans-snyders', 
    'fyodor-bronnikov', 'fyodor-vasilyev', 'george-morland', 'george-stubbs', 'gerard-david', 'gerrit-dou', 
    'gian-lorenzo-bernini', 'giorgio-vasari', 'giovanni-battista-tiepolo', 'giovanni-bellini', 'giovanni-boldini',
    'giovanni-domenico-tiepolo', 'guido-reni', 'gustave-courbet', 'gustave-dore', 'hans-holbein-the-younger',
    'hans-memling', 'henri-fantin-latour', 'henry-raeburn', 'hieronymus-bosch', 'ilya-repin', 'isaac-levitan',
    'ivan-aivazovsky', 'ivan-kramskoy', 'ivan-shishkin', 'ivan-vladimirov', 'jacob-jordaens', 'jacopo-pontormo', 
    'james-mcneill-whistler', 'james-tissot', 'jan-matejko', 'jan-steen', 'jan-van-eyck', 'jean-baptiste-simeon-chardin',
    'jean-fouquet', 'jean-francois-millet', 'jean-honore-fragonard', 'johan-hendrik-weissenbruch', 
    'john-atkinson-grimshaw', 'john-constable', 'john-crome', 'john-everett-millais', 'john-french-sloan', 
    'john-hoppner', 'john-singer-sargent', 'john-william-waterhouse', 'joseph-wright', 'joshua-reynolds', 
    'julius-leblanc-stewart', 'karl-bodmer', 'karl-bryullov', 'konstantin-makovsky', 'leonardo-da-vinci', 
    'lev-lagorio', 'lorenzo-lotto', 'louise-elisabeth-vigee-le-brun', 'luca-signorelli', 'mabuse', 
    'maerten-van-heemskerck', 'martin-schongauer', 'martiros-saryan', 'maurice-quentin-de-la-tour', 
    'michelangelo', 'n.c.-wyeth', 'nicholas-roerich', 'nikolai-ge', 'nikolay-bogdanov-belsky', 
    'odilon-redon', 'orest-kiprensky', 'paolo-uccello', 'paolo-veronese', 'paul-cezanne', 'pavel-svinyin', 
    'peter-paul-rubens', 'piero-della-francesca', 'pieter-bruegel-the-elder', 'pieter-de-hooch', 
    'pietro-longhi', 'pietro-perugino', 'pyotr-konchalovsky', 'raphael', 'rembrandt', 'rogier-van-der-weyden',
    'rudolf-von-alt', 'salvador-dali', 'sandro-botticelli', 'sir-lawrence-alma-tadema', 'taras-shevchenko',
    'theodore-gericault', 'theodore-rousseau', 'thomas-cole', 'thomas-eakins', 'thomas-gainsborough', 
    'tintoretto', 'titian', 'valentin-serov', 'vasily-perov', 'vasily-polenov', 'vasily-surikov', 'vasily-tropinin', 
    'vasily-vereshchagin', 'viktor-vasnetsov', 'vincent-van-gogh', 'vittore-carpaccio', 'vladimir-borovikovsky', 
    'vladimir-makovsky', 'volodymyr-orlovsky', 'william-adolphe-bouguereau', 'william-hogarth', 
    'william-shayer', 'william-turner', 'winslow-homer'
]

In [12]:
# Derive the class names from the dataset: the train sub-directories in sorted order.
# Only directories are kept so stray files in train_dir cannot shift class indices.
class_names = sorted(
    entry for entry in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, entry))
)
# Check the count first so a mismatch reports the actual numbers.
assert len(class_names) == NUM_CLASSES, f"Categories mismatch, expected {NUM_CLASSES}, found {len(class_names)}"
# Then check the exact names/order, reporting which names differ.
assert class_names == CLASS_NAMES, (
    "Something wrong in categories: "
    f"missing from dataset {sorted(set(CLASS_NAMES) - set(class_names))}, "
    f"unexpected in dataset {sorted(set(class_names) - set(CLASS_NAMES))}"
)

In [13]:
# Run inference over the whole test split and write the predictions CSV.
evaluate_model(
    model=model,
    test_dir=test_dir,
    class_names=class_names,
    input_size=INPUT_SIZE,
    norm_stats=NORM_STATS,
    device=DEVICE,
)

Done step 100/3960
Done step 200/3960
Done step 300/3960
Done step 400/3960
Done step 500/3960
Done step 600/3960
Done step 700/3960
Done step 800/3960
Done step 900/3960
Done step 1000/3960
Done step 1100/3960
Done step 1200/3960
Done step 1300/3960
Done step 1400/3960
Done step 1500/3960
Done step 1600/3960
Done step 1700/3960
Done step 1800/3960
Done step 1900/3960
Done step 2000/3960
Done step 2100/3960
Done step 2200/3960
Done step 2300/3960
Done step 2400/3960
Done step 2500/3960
Done step 2600/3960
Done step 2700/3960
Done step 2800/3960
Done step 2900/3960
Done step 3000/3960
Done step 3100/3960
Done step 3200/3960
Done step 3300/3960
Done step 3400/3960
Done step 3500/3960
Done step 3600/3960
Done step 3700/3960
Done step 3800/3960
Done step 3900/3960
Predictions saved to /kaggle/working/predictions.csv


In [18]:
# Sanity check: predictions.csv should now appear in the Kaggle working directory.
!ls -la /kaggle/working

total 392
drwxr-xr-x  5 root root   4096 Jan 22 22:14 .
drwxr-xr-x  5 root root   4096 Jan 22 22:07 ..
drwxr-xr-x 10 root root   4096 Jan 22 22:08 mlvm-project
-rw-r--r--  1 root root 374840 Jan 22 22:18 predictions.csv
drwxr-xr-x  9 root root   4096 Jan 22 22:08 src
drwxr-xr-x  2 root root   4096 Jan 22 22:08 .virtual_documents
