In [5]:
# CELL 1: Imports
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import random_split, DataLoader
from torchvision.datasets import ImageFolder
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score

In [6]:
# CELL 2: Data Setup & Transforms (Using your existing paths)
dataset_root = r"C:\Users\yozev\OneDrive\Desktop\artFiltered"

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
])

full_dataset = ImageFolder(root=dataset_root, transform=transform)
print("Classes:", full_dataset.classes)

Classes: ['Abstract_Expressionism', 'Art_Nouveau_Modern', 'Baroque', 'Cubism', 'Expressionism', 'Impressionism', 'Naive_Art_Primitivism', 'Northern_Renaissance', 'Post_Impressionism', 'Realism', 'Rococo', 'Romanticism', 'Symbolism']


In [7]:
# CELL 3: Data Splitting
dataset_size = len(full_dataset)
train_size = int(0.75 * dataset_size)
val_size = int(0.10 * dataset_size)
test_size = dataset_size - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

In [8]:
class HyperNetwork(nn.Module):
    def __init__(self, z_dim, target_hidden_dim):
        super(HyperNetwork, self).__init__()

        self.embedding = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU()
        )

        self.weight_generator = nn.Sequential(
            nn.Linear(256, 2048 * 512),
            nn.Tanh()
        )

        self.bias_generator = nn.Sequential(
            nn.Linear(256, 512),
            nn.Tanh()
        )

    def forward(self, z):
        embedded = self.embedding(z)
        weights = self.weight_generator(embedded)
        weights = weights.view(2048, 512)
        biases = self.bias_generator(embedded)
        return weights, biases

class HyperResNet(nn.Module):
    def __init__(self, num_classes, z_dim=64):
        super(HyperResNet, self).__init__()

        self.resnet = models.resnet50(weights='IMAGENET1K_V2')
        self.feature_extractor = nn.Sequential(*list(self.resnet.children())[:-1])

        # Fixed dimensions for clarity
        self.input_dim = 2048  # ResNet50 output dimension
        self.hidden_dim = 512  # Hidden dimension

        self.hyper_net = HyperNetwork(z_dim, self.hidden_dim)
        self.z_dim = z_dim
        self.z = nn.Parameter(torch.randn(z_dim))

        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        # Get 2048-dimensional features from ResNet
        features = self.feature_extractor(x)
        features = features.view(features.size(0), -1)  # Batch x 2048

        # Generate weights (2048 x 512) and biases (512)
        weights, biases = self.hyper_net(self.z)

        # Matrix multiplication: (Batch x 2048) @ (2048 x 512) = (Batch x 512)
        x = torch.matmul(features, weights) + biases
        x = torch.relu(x)

        # Final classification
        x = self.classifier(x)
        return x

In [9]:
# CELL 5: Model Setup and Optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize model
num_classes = len(full_dataset.classes)
model = HyperResNet(num_classes=num_classes)  # Make sure to instantiate the model
model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()

# Create parameter groups, excluding z from the main parameters
base_params = [p for n, p in model.named_parameters() if n != 'z']
optimizer = optim.Adam([
    {'params': base_params, 'lr': 1e-4},
    {'params': [model.z], 'lr': 1e-3}
])

# Learning rate scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=3,
    verbose=True
)

Using device: cuda




In [20]:
try:
    print("Training started...")
    for epoch in range(epochs):
        # Training
        model.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, desc=f'Epoch [{epoch+1}/{epochs}] Training')

        for batch_idx, (images, labels) in enumerate(train_bar):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            train_bar.set_postfix({
                'batch': f'{batch_idx+1}/{len(train_loader)}',
                'train_loss': f'{loss.item():.4f}'
            })

        avg_train_loss = running_loss / len(train_loader)

        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            val_bar = tqdm(val_loader, desc=f'Epoch [{epoch+1}/{epochs}] Validation')
            for images, labels in val_bar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        avg_val_loss = val_loss / len(val_loader)
        accuracy = 100 * correct / total

        # Metrics
        precision = precision_score(all_labels, all_preds, average='weighted')
        recall = recall_score(all_labels, all_preds, average='weighted')

        # Save best model
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'z_vector': model.z.data,
                'best_accuracy': best_accuracy,
            }, os.path.join(dataset_root, 'best_hyper_model.pth'))
            patience_counter = 0
        else:
            patience_counter += 1

        # Print summary
        print(f"\nEpoch [{epoch+1}/{epochs}] Summary:")
        print(f"Training Loss: {avg_train_loss:.4f}")
        print(f"Validation Loss: {avg_val_loss:.4f}")
        print(f"Validation Accuracy: {accuracy:.2f}%")
        print(f"Validation Precision: {precision:.4f}")
        print(f"Validation Recall: {recall:.4f}")
        print(f"Best Accuracy: {best_accuracy:.2f}%")
        print('-'*50)

        if patience_counter >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

        scheduler.step(avg_val_loss)

except KeyboardInterrupt:
    print("\nTraining interrupted. Saving the current model...")
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'z_vector': model.z.data,
        'best_accuracy': best_accuracy,
    }, os.path.join(dataset_root, 'interrupted_model.pth'))
    print("Model saved as 'interrupted_model.pth'. You can resume or test it later.")

print("\nTraining completed!")
print(f"Best validation accuracy: {best_accuracy:.2f}%")


Training started...


Epoch [1/20] Training: 100%|██████████| 1518/1518 [2:32:39<00:00,  6.03s/it, batch=1518/1518, train_loss=0.9664]
Epoch [1/20] Validation: 100%|██████████| 203/203 [01:27<00:00,  2.31it/s]



Epoch [1/20] Summary:
Training Loss: 1.3295
Validation Loss: 1.0151
Validation Accuracy: 65.68%
Validation Precision: 0.6692
Validation Recall: 0.6568
Best Accuracy: 65.68%
--------------------------------------------------


Epoch [2/20] Training: 100%|██████████| 1518/1518 [2:31:58<00:00,  6.01s/it, batch=1518/1518, train_loss=0.9015]
Epoch [2/20] Validation: 100%|██████████| 203/203 [01:25<00:00,  2.38it/s]



Epoch [2/20] Summary:
Training Loss: 0.8020
Validation Loss: 0.9324
Validation Accuracy: 68.57%
Validation Precision: 0.7017
Validation Recall: 0.6857
Best Accuracy: 68.57%
--------------------------------------------------


Epoch [3/20] Training: 100%|██████████| 1518/1518 [2:32:10<00:00,  6.01s/it, batch=1518/1518, train_loss=0.3642]
Epoch [3/20] Validation: 100%|██████████| 203/203 [01:26<00:00,  2.35it/s]



Epoch [3/20] Summary:
Training Loss: 0.5283
Validation Loss: 0.9890
Validation Accuracy: 69.81%
Validation Precision: 0.6987
Validation Recall: 0.6981
Best Accuracy: 69.81%
--------------------------------------------------


Epoch [4/20] Training: 100%|██████████| 1518/1518 [2:32:35<00:00,  6.03s/it, batch=1518/1518, train_loss=0.2916]
Epoch [4/20] Validation: 100%|██████████| 203/203 [01:25<00:00,  2.38it/s]



Epoch [4/20] Summary:
Training Loss: 0.3431
Validation Loss: 1.1484
Validation Accuracy: 69.78%
Validation Precision: 0.7091
Validation Recall: 0.6978
Best Accuracy: 69.81%
--------------------------------------------------


Epoch [5/20] Training: 100%|██████████| 1518/1518 [2:32:12<00:00,  6.02s/it, batch=1518/1518, train_loss=0.1060]
Epoch [5/20] Validation: 100%|██████████| 203/203 [01:25<00:00,  2.38it/s]



Epoch [5/20] Summary:
Training Loss: 0.2400
Validation Loss: 1.1959
Validation Accuracy: 70.22%
Validation Precision: 0.7053
Validation Recall: 0.7022
Best Accuracy: 70.22%
--------------------------------------------------


Epoch [6/20] Training: 100%|██████████| 1518/1518 [2:32:06<00:00,  6.01s/it, batch=1518/1518, train_loss=0.5914]
Epoch [6/20] Validation: 100%|██████████| 203/203 [01:25<00:00,  2.38it/s]



Epoch [6/20] Summary:
Training Loss: 0.1904
Validation Loss: 1.2323
Validation Accuracy: 70.04%
Validation Precision: 0.6997
Validation Recall: 0.7004
Best Accuracy: 70.22%
--------------------------------------------------


Epoch [7/20] Training: 100%|██████████| 1518/1518 [2:32:28<00:00,  6.03s/it, batch=1518/1518, train_loss=0.0219]
Epoch [7/20] Validation: 100%|██████████| 203/203 [01:25<00:00,  2.39it/s]



Epoch [7/20] Summary:
Training Loss: 0.0821
Validation Loss: 1.3284
Validation Accuracy: 71.64%
Validation Precision: 0.7168
Validation Recall: 0.7164
Best Accuracy: 71.64%
--------------------------------------------------


Epoch [8/20] Training: 100%|██████████| 1518/1518 [2:32:29<00:00,  6.03s/it, batch=1518/1518, train_loss=0.0884]
Epoch [8/20] Validation: 100%|██████████| 203/203 [01:24<00:00,  2.39it/s]



Epoch [8/20] Summary:
Training Loss: 0.0441
Validation Loss: 1.4324
Validation Accuracy: 72.06%
Validation Precision: 0.7197
Validation Recall: 0.7206
Best Accuracy: 72.06%
--------------------------------------------------


Epoch [9/20] Training: 100%|██████████| 1518/1518 [2:32:21<00:00,  6.02s/it, batch=1518/1518, train_loss=0.0018]
Epoch [9/20] Validation: 100%|██████████| 203/203 [01:25<00:00,  2.38it/s]



Epoch [9/20] Summary:
Training Loss: 0.0318
Validation Loss: 1.5184
Validation Accuracy: 72.26%
Validation Precision: 0.7208
Validation Recall: 0.7226
Best Accuracy: 72.26%
--------------------------------------------------


Epoch [10/20] Training: 100%|██████████| 1518/1518 [2:32:17<00:00,  6.02s/it, batch=1518/1518, train_loss=0.0041]
Epoch [10/20] Validation: 100%|██████████| 203/203 [01:25<00:00,  2.37it/s]



Epoch [10/20] Summary:
Training Loss: 0.0286
Validation Loss: 1.5207
Validation Accuracy: 72.39%
Validation Precision: 0.7250
Validation Recall: 0.7239
Best Accuracy: 72.39%
--------------------------------------------------


Epoch [11/20] Training: 100%|██████████| 1518/1518 [2:33:05<00:00,  6.05s/it, batch=1518/1518, train_loss=0.0040]
Epoch [11/20] Validation: 100%|██████████| 203/203 [01:26<00:00,  2.35it/s]



Epoch [11/20] Summary:
Training Loss: 0.0189
Validation Loss: 1.5783
Validation Accuracy: 72.11%
Validation Precision: 0.7220
Validation Recall: 0.7211
Best Accuracy: 72.39%
--------------------------------------------------


Epoch [12/20] Training:   4%|▍         | 61/1518 [06:25<2:33:24,  6.32s/it, batch=61/1518, train_loss=0.0004]


KeyboardInterrupt: 

In [10]:
checkpoint = torch.load(os.path.join(dataset_root, 'best_hyper_model.pth'))
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

test_loss = 0.0
correct = 0
total = 0
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc='Testing'):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        test_loss += loss.item()

        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

accuracy = 100 * correct / total
precision = precision_score(all_labels, all_preds, average='weighted')
recall = recall_score(all_labels, all_preds, average='weighted')

print("\nTest Set Results:")
print(f"Test Accuracy: {accuracy:.2f}%")
print(f"Test Precision: {precision:.4f}")
print(f"Test Recall: {recall:.4f}")

  checkpoint = torch.load(os.path.join(dataset_root, 'best_hyper_model.pth'))
Testing: 100%|██████████| 304/304 [00:29<00:00, 10.34it/s]


Test Set Results:
Test Accuracy: 71.69%
Test Precision: 0.7167
Test Recall: 0.7169



