# Libraries / Dependencies

In [3]:
import math
import multiprocessing
import os
import random
import shutil
import sys
import time
import traceback
from collections import Counter
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

# ArtStyleDataset class
Custom dataset built on `torch.utils.data.Dataset`.

In [4]:
class ArtStyleDataset(Dataset):
    """Image dataset over ``<data_dir>/<mode>/<style_name>/*.jpg``.

    Every sub-directory of ``<data_dir>/<mode>`` is one art style; its name is
    mapped to an integer label through ``label_to_idx``. Sub-directories whose
    name is not in ``label_to_idx`` are reported and skipped.

    Args:
        data_dir: Root directory that contains the split folders.
        transform: Optional callable applied to the RGB image array produced
            by OpenCV (after BGR -> RGB conversion).
        mode: Name of the split sub-directory to read (e.g. ``'train'``).
        label_to_idx: Mapping ``style_name -> label``. Required, so that all
            splits share one consistent class mapping.

    Raises:
        ValueError: If ``label_to_idx`` is not provided.
    """

    # How many consecutive samples __getitem__ tries before giving up when
    # images cannot be read or transformed.
    max_load_retries = 10

    def __init__(self, data_dir, transform=None, mode='train', label_to_idx=None):
        self.data_dir = Path(data_dir) / mode
        self.transform = transform
        self.mode = mode

        self.images = []
        self.labels = []
        self.label_to_idx = label_to_idx  # supplied by the caller

        if self.label_to_idx is None:
            raise ValueError("label_to_idx musi byƒá przekazane z zewnƒÖtrz, aby zapewniƒá sp√≥jno≈õƒá klas.")

        # Sorted listings make the sample order deterministic; iterdir()/glob()
        # order depends on the operating system and filesystem.
        for style_dir in sorted(self.data_dir.iterdir()):
            if not style_dir.is_dir():
                continue
            style_name = style_dir.name

            if style_name not in self.label_to_idx:
                print(f"‚ö†Ô∏è Uwaga: styl '{style_name}' nie znajduje siƒô w label_to_idx. Pomijam.")
                continue

            label_idx = self.label_to_idx[style_name]
            for img_path in sorted(style_dir.glob('*.jpg')):
                self.images.append(str(img_path))
                self.labels.append(label_idx)

        self.idx_to_label = {v: k for k, v in self.label_to_idx.items()}
        print(f"[{mode}] Znaleziono {len(self.images)} obraz√≥w w {len(self.label_to_idx)} klasach.")

    def __len__(self):
        return len(self.images)

    def _load(self, idx):
        """Read sample ``idx`` from disk, convert BGR->RGB and apply ``transform``.

        Returns ``(image, label, image_path)``; raises on any failure.
        """
        img_path = self.images[idx]
        img = cv2.imread(img_path)
        if img is None:
            raise ValueError("Nie mo≈ºna wczytaƒá obrazu")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx], img_path

    def __getitem__(self, idx):
        """Return ``(image, label, image_path)`` for sample ``idx``.

        If the image cannot be loaded or transformed, the error is printed and
        the following sample (wrapping around) is returned instead, so the
        returned label always belongs to the returned image. A black
        placeholder labelled as class 0 would silently add mislabelled
        training samples.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If ``max_load_retries`` consecutive samples fail.
        """
        # Index validation stays outside the recovery loop: a bad idx is a
        # caller error, not a broken image.
        img_path = self.images[idx]
        n = len(self.images)
        attempts = min(n, self.max_load_retries)
        last_error = None
        for offset in range(attempts):
            candidate = (idx + offset) % n
            try:
                return self._load(candidate)
            except Exception as e:
                last_error = e
                print(f"B≈ÇƒÖd przetwarzania {self.images[candidate]}: {e}")
        raise RuntimeError(
            f"Failed to load {attempts} consecutive samples starting at {img_path}"
        ) from last_error



# Image Processing

In [5]:
def preprocess_images(input_dir, output_dir, target_size=(512, 512)):
    """
    Letterbox every ``*.jpg`` of every style folder into ``output_dir``.

    ``input_dir`` must contain one sub-directory per art style. Images are
    collected recursively from each style directory, at any nesting depth and
    including files placed directly in the style folder. They are written flat
    into ``output_dir/<style_name>/`` by ``process_image``. Looking only one
    level deep would skip loose files whenever a style also had sub-folders.

    Args:
        input_dir: Root directory with one sub-directory per style.
        output_dir: Destination root; created if missing.
        target_size: ``(width, height)`` forwarded to ``process_image``.
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    for style_dir in sorted(d for d in input_path.iterdir() if d.is_dir()):
        style_name = style_dir.name
        output_style_dir = output_path / style_name
        output_style_dir.mkdir(exist_ok=True)

        image_files = sorted(p for p in style_dir.rglob('*.jpg') if p.is_file())

        # Outputs are flattened by file name, so equal names coming from
        # different sub-folders overwrite each other. Make that visible.
        name_counts = Counter(p.name for p in image_files)
        duplicates = [name for name, count in name_counts.items() if count > 1]
        if duplicates:
            print(f"Warning: {len(duplicates)} duplicate file name(s) in style "
                  f"'{style_name}'; later files overwrite earlier ones.")

        print(f"Przetwarzanie {len(image_files)} obraz√≥w w stylu {style_name}")
        for img_path in tqdm(image_files, desc=f"Styl: {style_name}"):
            process_image(img_path, output_style_dir, target_size)

    print("Zako≈Ñczono preprocessing obraz√≥w!")

def process_image(img_path, output_dir, target_size):
    """
    Letterbox a single image into ``target_size`` and save it to ``output_dir``.

    Steps:
      1. Read the image with OpenCV and convert BGR -> RGB.
      2. If the aspect ratio is more extreme than 3:1 (or 1:3), centre-crop the
         long side to 1.5x the short side.
      3. Downscale (never upscale) so that BOTH sides fit inside
         ``target_size`` while keeping the aspect ratio.
      4. Centre the result on a black canvas of exactly ``target_size`` and
         write it to ``output_dir / img_path.name``.

    Args:
        img_path: ``pathlib.Path`` of the source image.
        output_dir: ``pathlib.Path`` of an existing destination directory.
        target_size: ``(width, height)`` of the output canvas.

    Errors are printed and swallowed so that one bad file does not stop a
    batch run.
    """
    try:
        img = cv2.imread(str(img_path))
        if img is None:
            print(f"Nie uda≈Ço siƒô wczytaƒá: {img_path}")
            return

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # BGR -> RGB

        h, w = img.shape[:2]
        aspect_ratio = w / h

        # Very elongated images: keep only a central 1.5:1 (or 1:1.5) window
        # instead of squashing the whole image into a thin strip.
        if aspect_ratio > 3 or aspect_ratio < 1/3:
            if w > h:
                new_w = int(h * 1.5)
                start_x = int((w - new_w) / 2)
                img = img[:, start_x:start_x+new_w]
            else:
                new_h = int(w * 1.5)
                start_y = int((h - new_h) / 2)
                img = img[start_y:start_y+new_h, :]

        target_w, target_h = target_size
        h, w = img.shape[:2]
        # One scale factor that fits both sides into the target box, capped at
        # 1.0 so small images are only padded. Limiting just the longer side
        # overflowed the canvas whenever target_size was not square.
        scale = min(target_w / w, target_h / h, 1.0)
        new_w = min(target_w, max(1, round(w * scale)))
        new_h = min(target_h, max(1, round(h * scale)))

        img_resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

        # Letterbox: centre the resized image on a black canvas.
        padded_img = np.zeros((target_h, target_w, 3), dtype=np.uint8)
        start_x = (target_w - new_w) // 2
        start_y = (target_h - new_h) // 2
        padded_img[start_y:start_y+new_h, start_x:start_x+new_w] = img_resized

        output_file = output_dir / img_path.name
        cv2.imwrite(str(output_file), cv2.cvtColor(padded_img, cv2.COLOR_RGB2BGR))

    except Exception as e:
        print(f"B≈ÇƒÖd podczas przetwarzania {img_path}: {e}")


# Data splitting
Split into train, validation and test sets.

In [6]:
def split_data_with_balancing(processed_dir, output_dir,
                              train_ratio=0.7, val_ratio=0.15, test_ratio=0.15,
                              min_class_size=2000, max_class_size=3000,
                              random_state=42):
    """
    Balance the class sizes, then split every class into train / val / test.

    Stage 1 (balancing), for each style sub-directory of ``processed_dir``:
      * more than ``max_class_size`` images -> random subsample of exactly
        ``max_class_size`` images;
      * fewer than ``min_class_size`` images -> all originals plus rotated
        copies (-5, -3, 3, 5 degrees) until ``min_class_size`` is reached or
        every original has been rotated by every angle;
      * otherwise -> copied unchanged.

    Stage 2 (splitting): an original and all of its rotated copies are always
    assigned to the SAME split. Splitting the files independently let a
    rotated copy of a validation/test image land in the training set, which
    leaks evaluation data into training and inflates the reported accuracy.

    WARNING: ``output_dir/train``, ``output_dir/val``, ``output_dir/test`` and
    the scratch folder ``output_dir/balanced_temp`` are deleted and recreated.

    Args:
        processed_dir: Directory with one sub-directory of ``*.jpg`` per style.
        output_dir: Destination root for the split folders.
        train_ratio, val_ratio, test_ratio: Split fractions; must sum to 1.
        min_class_size, max_class_size: Balancing bounds (see above).
        random_state: Seed for subsampling and shuffling. A private
            ``random.Random`` is used, so the global ``random`` state is not
            reset.

    Returns:
        dict with keys ``'train'``, ``'val'``, ``'test'``; each value is a
        ``Counter`` mapping style name -> number of images in that split.

    Raises:
        ValueError: If the ratios do not sum to 1 (within float tolerance).
    """
    # Tolerant comparison: e.g. 0.7 + 0.2 + 0.1 evaluates to 0.9999999999999999.
    if not math.isclose(train_ratio + val_ratio + test_ratio, 1.0, abs_tol=1e-9):
        raise ValueError("Proporcje podzia≈Çu muszƒÖ sumowaƒá siƒô do 1.0")

    processed_path = Path(processed_dir)
    output_path = Path(output_dir)

    # Scratch directory for the balanced data plus the final split folders.
    balanced_path = output_path / "balanced_temp"
    split_dirs = {name: output_path / name for name in ('train', 'val', 'test')}
    for directory in [balanced_path, *split_dirs.values()]:
        if directory.exists():
            shutil.rmtree(directory)
        directory.mkdir(parents=True, exist_ok=True)

    style_dirs = [d for d in processed_path.iterdir() if d.is_dir()]

    print("Rozpoczynam balansowanie klas...")

    # style name -> list of groups; every group is [original, *rotated copies]
    balanced_groups = {
        style_dir.name: _balance_style(style_dir, balanced_path / style_dir.name,
                                       min_class_size, max_class_size, random_state)
        for style_dir in style_dirs
    }

    print("\nZako≈Ñczono balansowanie klas. Dzielenie na zbiory...")

    dataset_stats = {'train': Counter(), 'val': Counter(), 'test': Counter()}

    for style_name, groups in balanced_groups.items():
        print(f"Dzielenie klasy {style_name} na zbiory...")

        assignment = _split_groups(groups, train_ratio, val_ratio, random_state)
        for split_name, files in assignment.items():
            target_dir = split_dirs[split_name] / style_name
            target_dir.mkdir(exist_ok=True)
            for src_file in files:
                shutil.copy2(src_file, target_dir / src_file.name)
            dataset_stats[split_name][style_name] = len(files)

        print(f"  - Zbi√≥r treningowy: {dataset_stats['train'][style_name]} obraz√≥w")
        print(f"  - Zbi√≥r walidacyjny: {dataset_stats['val'][style_name]} obraz√≥w")
        print(f"  - Zbi√≥r testowy: {dataset_stats['test'][style_name]} obraz√≥w")

    shutil.rmtree(balanced_path)

    print("\nZako≈Ñczono podzia≈Ç danych na zbiory!")
    return dataset_stats


def _balance_style(style_dir, balanced_style_dir, min_class_size, max_class_size, random_state):
    """Subsample/copy/augment one style into ``balanced_style_dir``; return its file groups."""
    style_name = style_dir.name
    print(f"\nPrzetwarzanie stylu: {style_name}")
    balanced_style_dir.mkdir(exist_ok=True)

    image_files = list(style_dir.glob('*.jpg'))
    total_images = len(image_files)
    print(f"Znaleziono {total_images} oryginalnych obraz√≥w")

    if total_images > max_class_size:
        print(f"Redukcja liczby obraz√≥w z {total_images} do {max_class_size}")
        # Same selection as random.seed(random_state); random.sample(...),
        # but without touching the global RNG.
        selected = random.Random(random_state).sample(image_files, max_class_size)
    elif total_images < min_class_size:
        print(f"Klasa {style_name} ma {total_images} obraz√≥w - mniej ni≈º {min_class_size}. Rozpoczynam uproszczony oversamplig.")
        selected = image_files
    else:
        print(f"Klasa {style_name} ma odpowiedniƒÖ liczbƒô obraz√≥w ({total_images}). Kopiowanie bez zmian.")
        selected = image_files

    groups = []
    for src_file in selected:
        dst_file = balanced_style_dir / src_file.name
        shutil.copy2(src_file, dst_file)
        groups.append([dst_file])

    if total_images < min_class_size:
        augmented_count = _augment_groups(groups, balanced_style_dir, min_class_size - total_images)
        print(f"Utworzono {augmented_count} augmentowanych obraz√≥w")

    return groups


def _augment_groups(groups, out_dir, needed, rotation_angles=(-5, -3, 3, 5)):
    """Append up to ``needed`` rotated copies to ``groups`` (in place); return how many were made."""
    created = 0
    for group in tqdm(groups, desc="Augmentacja obraz√≥w"):
        if created >= needed:
            break
        original = group[0]
        # Read each source once (not once per angle). Rotation works per
        # channel, so no BGR<->RGB round-trip is needed.
        img = cv2.imread(str(original))
        if img is None:
            continue
        h, w = img.shape[:2]
        center = (w // 2, h // 2)
        for angle in rotation_angles:
            if created >= needed:
                break
            rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
            rotated = cv2.warpAffine(img, rotation_matrix, (w, h))
            augmented_path = out_dir / f"{original.stem}_aug_r{angle}{original.suffix}"
            cv2.imwrite(str(augmented_path), rotated)
            group.append(augmented_path)
            created += 1
    return created


def _split_groups(groups, train_ratio, val_ratio, random_state):
    """Shuffle the groups and assign whole groups to train/val/test; return file lists per split."""
    n_files = sum(len(group) for group in groups)
    train_size = int(n_files * train_ratio)
    val_size = int(n_files * val_ratio)

    shuffled = list(groups)
    random.Random(random_state).shuffle(shuffled)

    # Fill train, then val, with whole groups; everything else goes to test.
    # With single-file groups this reproduces exact train_size/val_size counts.
    assignment = {'train': [], 'val': [], 'test': []}
    for group in shuffled:
        if len(assignment['train']) < train_size:
            assignment['train'].extend(group)
        elif len(assignment['val']) < val_size:
            assignment['val'].extend(group)
        else:
            assignment['test'].extend(group)
    return assignment


# Data distribution between categories

In [7]:
def plot_class_distribution(dataset_stats):
    """
    Plot the per-class image counts of the train / val / test splits.

    Draws one pie chart per split with a single shared legend, saved to
    ``class_distribution.png``. Also draws a grouped bar chart of the same
    counts, saved to ``class_distribution_bar.png``. Both files are written
    to the current working directory.

    Args:
        dataset_stats: Mapping with keys ``'train'``, ``'val'``, ``'test'``;
            each value maps class name -> image count (as returned by
            ``split_data_with_balancing``).
    """
    splits = ['train', 'val', 'test']
    # One class order shared by all pies, so a colour means the same class in
    # every chart and the shared legend is correct. The legend must not rely
    # on the last split's key order or on a class being present everywhere.
    class_names = list(dict.fromkeys(
        name for split in splits for name in dataset_stats[split]))

    fig, axs = plt.subplots(1, 3, figsize=(18, 6))
    wedges = []
    for ax, dataset_type in zip(axs, splits):
        stats = dataset_stats[dataset_type]
        values = [stats.get(name, 0) for name in class_names]
        wedges, _, _ = ax.pie(values, labels=None, autopct='%1.1f%%', startangle=90)
        ax.set_title(f'Rozk≈Çad klas w zbiorze {dataset_type}')

    # Shared legend with explicit handles (same wedge order in every pie).
    fig.legend(wedges, class_names, loc='lower center', ncol=max(1, min(5, len(class_names))))
    fig.tight_layout()
    fig.subplots_adjust(bottom=0.15)
    fig.savefig('class_distribution.png')
    plt.show()

    # Bar chart for an easier side-by-side comparison of the splits.
    ax = pd.DataFrame(dataset_stats).plot(kind='bar', figsize=(12, 6))
    ax.set_title('Liczba obraz√≥w w ka≈ºdej klasie wed≈Çug zbioru')
    ax.set_xlabel('Klasa')
    ax.set_ylabel('Liczba obraz√≥w')
    ax.tick_params(axis='x', labelrotation=45)
    bar_fig = ax.get_figure()
    bar_fig.tight_layout()
    bar_fig.savefig('class_distribution_bar.png')
    plt.show()

def denormalize(img_tensor, mean, std):
    """
    Undo ``transforms.Normalize(mean, std)``: return ``img_tensor * std + mean``.

    ``mean`` and ``std`` are per-channel sequences, reshaped to ``(C, 1, 1)``
    so they broadcast over a ``(C, H, W)`` image or a ``(N, C, H, W)`` batch.
    For images normalised from ``ToTensor`` output the result is back in
    ``[0, 1]`` up to rounding; callers may still want to clip.

    The statistics are created on ``img_tensor``'s device, so CUDA tensors
    work as well. CPU-only statistics would raise a device-mismatch error.
    """
    mean = torch.as_tensor(mean, device=img_tensor.device).view(-1, 1, 1)
    std = torch.as_tensor(std, device=img_tensor.device).view(-1, 1, 1)
    return img_tensor * std + mean

def plot_class_examples(dataset, num_examples=5):
    """
    Show up to ``num_examples`` random images from every class of ``dataset``.

    Class membership is read from ``dataset.labels`` when the dataset exposes
    it (as ``ArtStyleDataset`` does), so only the displayed images are loaded.
    Iterating the whole dataset instead would decode and transform every
    image just to obtain its label.

    Tensor images in ``(C, H, W)`` layout are assumed to be normalised with the
    ImageNet mean/std and are denormalised before display.

    Args:
        dataset: Dataset returning ``(image, label, path)`` and exposing
            ``idx_to_label``.
        num_examples: Maximum number of images shown per class.
    """
    labels = getattr(dataset, 'labels', None)
    if labels is None:
        # Fallback for datasets without a label list: load every sample.
        labels = [dataset[i][1] for i in range(len(dataset))]

    # Group sample indices by class.
    class_indices = {}
    for idx, label in enumerate(labels):
        class_indices.setdefault(label, []).append(idx)

    for label, indices in class_indices.items():
        class_name = dataset.idx_to_label[label]

        sample_indices = random.sample(indices, min(num_examples, len(indices)))

        # squeeze=False always yields a 2-D Axes array, so one example needs
        # no special case.
        fig, axs = plt.subplots(1, len(sample_indices), figsize=(15, 3), squeeze=False)
        fig.suptitle(f'Przyk≈Çady klasy: {class_name}')

        for ax, idx in zip(axs[0], sample_indices):
            img, _, img_path = dataset[idx]

            # Convert a PyTorch tensor into something matplotlib can show.
            if isinstance(img, torch.Tensor):
                if img.dim() == 3 and img.shape[0] in [1, 3, 4]:  # [C, H, W]
                    img = denormalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                    img = img.permute(1, 2, 0).cpu().numpy()
                    img = np.clip(img, 0, 1)  # keep values inside [0, 1]

                    if img.shape[2] == 1:
                        img = img.squeeze(2)
                else:
                    img = img.cpu().numpy()

            ax.imshow(img)
            ax.set_title(Path(img_path).name)
            ax.axis('off')

        fig.tight_layout()
        plt.show()


# Dataloaders

In [8]:
def create_data_loaders(data_dir, batch_size=32, img_size=224, num_workers=4):
    """
    Build the train/val/test ``ArtStyleDataset``s and their ``DataLoader``s.

    The label mapping is derived once from the sorted class-folder names of
    ``<data_dir>/train`` and shared by all three splits.

    Args:
        data_dir: Root containing the ``train``, ``val`` and ``test`` folders.
        batch_size: Batch size of every loader.
        img_size: Side of the square centre crop fed to the network.
        num_workers: Worker processes per loader. Reduced to 0, with a printed
            note, when workers would be started by ``spawn``/``forkserver``
            while ``ArtStyleDataset`` is defined in an interactive ``__main__``
            (e.g. a Jupyter notebook on Windows). Such workers cannot import
            the class, and loading hangs or crashes on the first batch.

    Returns:
        dict with the three loaders, the three datasets and both transforms.
    """
    def get_label_to_idx(data_dir):
        """Map the sorted class-folder names of ``<data_dir>/train`` to 0..K-1."""
        train_dir = Path(data_dir) / 'train'
        labels = sorted(d.name for d in train_dir.iterdir() if d.is_dir())
        return {label: idx for idx, label in enumerate(labels)}

    # One consistent label mapping for every split.
    label_to_idx = get_label_to_idx(data_dir)

    # ImageNet mean/std normalisation, shared by both pipelines.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(img_size + 32),
        transforms.CenterCrop(img_size),
        transforms.RandomRotation(5),  # light augmentation, training only
        transforms.ToTensor(),
        normalize,
    ])

    val_test_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(img_size + 32),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = ArtStyleDataset(data_dir, transform=train_transform, mode='train', label_to_idx=label_to_idx)
    val_dataset = ArtStyleDataset(data_dir, transform=val_test_transform, mode='val', label_to_idx=label_to_idx)
    test_dataset = ArtStyleDataset(data_dir, transform=val_test_transform, mode='test', label_to_idx=label_to_idx)

    # Non-fork workers re-import __main__. A class defined interactively
    # (no __main__.__file__, as in Jupyter) is not found there.
    main_module = sys.modules.get('__main__')
    defined_interactively = (ArtStyleDataset.__module__ == '__main__'
                             and not hasattr(main_module, '__file__'))
    if num_workers > 0 and defined_interactively and multiprocessing.get_start_method() != 'fork':
        print(f"Note: '{multiprocessing.get_start_method()}' worker processes cannot import "
              "ArtStyleDataset defined in this notebook; using num_workers=0. "
              "Move the class to a .py module to enable parallel loading.")
        num_workers = 0

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),  # pinning only helps host->GPU copies
        persistent_workers=num_workers > 0,    # reuse workers across epochs
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    return {
        'train_loader': train_loader,
        'val_loader': val_loader,
        'test_loader': test_loader,
        'train_dataset': train_dataset,
        'val_dataset': val_dataset,
        'test_dataset': test_dataset,
        'train_transform': train_transform,
        'val_test_transform': val_test_transform
    }


# Main

In [9]:
# Directory paths. The defaults are the author's local Windows paths. Set the
# ARTSTYLE_INPUT_DIR / ARTSTYLE_OUTPUT_DIR environment variables to run the
# notebook elsewhere without editing this cell.
input_directory = os.environ.get("ARTSTYLE_INPUT_DIR", r"C:\Users\Dominik\Desktop\DataSet")
output_directory = os.environ.get("ARTSTYLE_OUTPUT_DIR", r"C:\Users\Dominik\Desktop\outData")

# Parameters
target_size = (224, 224)   # (width, height) of the preprocessed images
img_size = 224             # side of the square crop fed to the network
batch_size = 32
min_class_size = 2000      # smaller classes are oversampled with rotations
max_class_size = 3000      # larger classes are randomly subsampled

# Global seeds so weight initialisation, shuffling and sampling are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Output paths
processed_dir = os.path.join(output_directory, "processed")
split_dir = os.path.join(output_directory, "split")

In [10]:
#print("\n--- ETAP 1: Preprocessing obraz√≥w ---")
#preprocess_images(input_directory, processed_dir, target_size=target_size)

In [11]:
# print("\n--- ETAP 2: Podzia≈Ç na zbiory treningowy, walidacyjny i testowy z balansowaniem klas ---")
# dataset_stats = split_data_with_balancing(
#     processed_dir,
#     split_dir,
#     train_ratio=0.7,
#     val_ratio=0.15,
#     test_ratio=0.15,
#     min_class_size=min_class_size,
#     max_class_size=max_class_size
# )

In [12]:
# print("\n--- ETAP 3: Analiza rozk≈Çadu klas ---")
# plot_class_distribution(dataset_stats)

In [13]:
# Build the train/val/test loaders from the split directory prepared earlier.
print("\n--- ETAP 4: Tworzenie DataLoader√≥w ---")
data_loaders = create_data_loaders(split_dir, img_size=img_size, batch_size=batch_size)


--- ETAP 4: Tworzenie DataLoader√≥w ---
[train] Znaleziono 23351 obraz√≥w w 13 klasach.
[val] Znaleziono 5003 obraz√≥w w 13 klasach.
[test] Znaleziono 5007 obraz√≥w w 13 klasach.


In [14]:
# print("\n--- ETAP 5: Wizualizacja przyk≈Çad√≥w z ka≈ºdej klasy ---")
# plot_class_examples(data_loaders['train_dataset'])

# Własna architektura CNN

In [15]:
class ArtStyleCNN(nn.Module):
    """Five-stage convolutional classifier for art styles.

    Every stage is Conv3x3 (padding 1) -> BatchNorm -> ReLU. The first four
    stages end with 2x2 max-pooling (224 -> 112 -> 56 -> 28 -> 14 for a
    224x224 input). The last stage ends with global average pooling to 1x1,
    so the head always receives a 512-dimensional vector.

    Module indices are fixed (``conv_layers`` 0..19, ``fc_layers`` 0..4);
    keep them stable so saved ``state_dict`` checkpoints keep loading.

    Args:
        num_classes: Number of output logits.
    """

    # (in_channels, out_channels) of each convolutional stage.
    _STAGES = ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512))

    def __init__(self, num_classes):
        super().__init__()

        blocks = []
        last_stage = len(self._STAGES) - 1
        for stage, (c_in, c_out) in enumerate(self._STAGES):
            blocks.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ])
            # Halve the resolution after each stage except the last, which
            # collapses the spatial grid to 1x1 instead.
            blocks.append(nn.AdaptiveAvgPool2d((1, 1)) if stage == last_stage else nn.MaxPool2d(2))
        self.conv_layers = nn.Sequential(*blocks)

        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Map a batch of 3-channel images to class logits."""
        return self.fc_layers(self.conv_layers(x))

# Funkcja walidacyjna

In [16]:
def evaluate_model(model, val_loader, device):
    """
    Evaluate ``model`` on ``val_loader`` and report metrics.

    Prints the overall accuracy and a per-class ``classification_report``, and
    shows a confusion-matrix plot. The model is left in eval mode.

    Args:
        model: Classifier returning logits of shape ``(batch, num_classes)``.
        val_loader: DataLoader yielding ``(inputs, labels, paths)``; its
            ``dataset`` must expose ``idx_to_label``.
        device: Device the model lives on.

    Returns:
        Accuracy in percent.
    """
    model.eval()
    y_true = []
    y_pred = []

    with torch.no_grad():
        for inputs, labels, _ in tqdm(val_loader, desc="üß™ Walidacja"):
            inputs = inputs.to(device, non_blocking=True)
            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())

    acc = 100. * np.mean(np.array(y_true) == np.array(y_pred))
    print(f"üéØ Accuracy: {acc:.2f}%")

    idx_to_label = val_loader.dataset.idx_to_label
    class_ids = list(range(len(idx_to_label)))
    class_names = [idx_to_label[i] for i in class_ids]

    # Pass the full label set explicitly. Otherwise sklearn infers the classes
    # from y_true/y_pred only, so a class that never occurs makes the report
    # raise (target_names size mismatch) and shrinks the confusion matrix.
    print("\nüìã Classification Report:")
    print(classification_report(y_true, y_pred, labels=class_ids, target_names=class_names,
                                digits=3, zero_division=0))

    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot(xticks_rotation=45, cmap="Blues")
    plt.title("Macierz pomy≈Çek (Validation Set)")
    plt.tight_layout()
    plt.show()

    return acc


# Funkcja treningowa

In [17]:
def train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=10):
    """
    Train ``model`` for ``num_epochs`` epochs, validating and checkpointing each epoch.

    After every epoch the model is evaluated with ``evaluate_model`` on
    ``val_loader``. Its ``state_dict`` is saved to ``model_epoch_<n>.pt`` in
    the current working directory. Finally the per-epoch mean training loss
    and validation accuracy are plotted.

    Args:
        model: Network to train; moved to ``device``.
        train_loader: DataLoader yielding ``(inputs, labels, paths)``.
        val_loader: DataLoader used for validation after every epoch.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Target ``torch.device``.
        num_epochs: Number of epochs.

    Returns:
        ``(train_losses, val_accuracies)``: per-epoch mean training loss and
        validation accuracy in percent.

    Raises:
        RuntimeError: If every training batch of an epoch fails.
    """
    train_losses, val_accuracies = [], []

    # torch.autograd.set_detect_anomaly(True) is deliberately NOT enabled: it
    # is a debugging aid that slows every backward pass considerably and
    # changes global autograd state.
    model = model.to(device)

    for epoch in range(num_epochs):
        print(f"\n--- Epoch {epoch+1}/{num_epochs} ---")
        start_time = time.time()

        model.train()
        running_loss = 0.0
        correct, total = 0, 0
        n_batches = 0   # batches that completed successfully
        n_skipped = 0   # batches skipped because of an exception

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}", unit="batch")

        for inputs, labels, _ in pbar:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            try:
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
            except Exception as e:
                # Best effort: skip the batch, but count it so it is visible.
                n_skipped += 1
                print(f"B≈ÇƒÖd w batchu: {e}")
                continue

            n_batches += 1
            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Running averages over the successfully processed batches.
            pbar.set_postfix({"Loss": f"{running_loss / n_batches:.4f}",
                              "Train Acc": f"{100. * correct / total:.2f}%"})

        if n_batches == 0:
            raise RuntimeError(f"All {n_skipped} training batches failed in epoch {epoch+1}")

        epoch_loss = running_loss / n_batches
        train_acc = 100. * correct / total
        train_losses.append(epoch_loss)  # one value per epoch, matching the plot

        val_acc = evaluate_model(model, val_loader, device)
        val_accuracies.append(val_acc)

        torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pt")
        print(f"üíæ Zapisano model po epoce {epoch+1}")

        print(f"Epoch time: {time.time() - start_time:.2f}s")
        if n_skipped:
            print(f"Skipped {n_skipped} batch(es) because of errors")
        print(f"Train Loss: {epoch_loss:.4f} | Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")
        torch.cuda.empty_cache()

        print(f"‚úîÔ∏è Zako≈Ñczono epokƒô {epoch+1}")

    # Loss and accuracy curves, one point per epoch.
    epochs = range(1, len(train_losses) + 1)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))

    ax_loss.plot(epochs, train_losses, label="Train Loss", color='red')
    ax_loss.set_title("Train Loss Over Epochs")
    ax_loss.set_xlabel("Epoch")
    ax_loss.set_ylabel("Loss")
    ax_loss.legend()

    ax_acc.plot(epochs, val_accuracies, label="Validation Accuracy", color='blue')
    ax_acc.set_title("Validation Accuracy Over Epochs")
    ax_acc.set_xlabel("Epoch")
    ax_acc.set_ylabel("Accuracy (%)")
    ax_acc.legend()

    fig.tight_layout()
    plt.show()

    return train_losses, val_accuracies


# Trenowanie

In [18]:
# Environment check: PyTorch version, the CUDA version it was built against
# (None for CPU-only builds) and whether a GPU is usable right now.
print("\nDane cuda:")
for info in (torch.__version__, torch.version.cuda, torch.cuda.is_available()):
    print(info)
print("\n")


Dane cuda:
2.5.1
11.8
True




In [None]:
# Build the model for the classes found in the training split and train it.
print("Liczba klas:", len(data_loaders['train_dataset'].label_to_idx))
print("Mapowanie klas:", data_loaders['train_dataset'].label_to_idx)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"U≈ºywane urzƒÖdzenie: {device}")

num_classes = len(data_loaders['train_dataset'].label_to_idx)
model = ArtStyleCNN(num_classes)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

try:
    train_model(model,
                data_loaders['train_loader'],
                data_loaders['val_loader'],
                criterion, optimizer,
                device,
                num_epochs=10)
    print("‚úÖ Trenowanie zako≈Ñczone poprawnie.")
except Exception as e:
    print(f"‚ùå B≈ÇƒÖd w train_model: {e}")
    # The message alone hides where the failure happened; show the traceback.
    traceback.print_exc()


Liczba klas: 13
Mapowanie klas: {'Academic_Art': 0, 'Art_Nouveau': 1, 'Baroque': 2, 'Expressionism': 3, 'Japanese_Art': 4, 'Neoclassicism': 5, 'Primitivism': 6, 'Realism': 7, 'Renaissance': 8, 'Rococo': 9, 'Romanticism': 10, 'Symbolism': 11, 'Western_Medieval': 12}
U≈ºywane urzƒÖdzenie: cuda

--- Epoch 1/10 ---


Epoch 1:   0%|          | 0/730 [00:00<?, ?batch/s]