In [1]:
# Multiclass Classification with ResNet-18 and wandb Logging

# Imports
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18
import torch.nn as nn
import torch.optim as optim
from PIL import Image, UnidentifiedImageError
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import wandb


In [2]:
# Hyperparameters
lr = 0.0001
batch_size = 32
SEED = 42

# Seed torch's global RNG up front so that weight initialisation and
# shuffling, which draw from it, repeat from one run to the next.
torch.manual_seed(SEED)

# Device configuration: use the GPU when one is visible, otherwise the CPU
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize wandb and store the hyperparameters with the run so that
# logged curves can be traced back to the settings that produced them.
wandb.init(
    project="handwritten-multiclass-resnet18",
    config={"lr": lr, "batch_size": batch_size, "seed": SEED},
)

# RandomPadToSize for preprocessing
class RandomPadToSize:
    """Pad an image with fill value 255 so it sits centred at a fixed size.

    Despite the name, nothing here is random: the missing pixels are split
    evenly between opposite sides, and when the amount is odd the extra
    pixel goes to the right / bottom edge.

    Args:
        target_height: Height of the padded result, in pixels.
        target_width: Width of the padded result, in pixels.
    """

    def __init__(self, target_height, target_width):
        self.target_height, self.target_width = target_height, target_width

    def __call__(self, img):
        """Return `img` padded on all four sides up to the target size.

        `img` must expose `.size` as `(width, height)`.

        NOTE(review): for an image larger than the target the computed
        padding values are negative; how `transforms.functional.pad` treats
        negative padding is not visible here - confirm before relying on it.
        """
        src_width, src_height = img.size
        missing_w = self.target_width - src_width
        missing_h = self.target_height - src_height

        left, top = missing_w // 2, missing_h // 2
        # Right/bottom receive the remainder, so odd amounts land there.
        borders = (left, top, missing_w - left, missing_h - top)
        return transforms.functional.pad(img, borders, fill=255)


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33moscfah-1[0m ([33mertveh-4-lule-university-of-technology[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:

# Transformations
# Steps run in order: centre-pad to 224x224, convert to a tensor, then
# normalise with mean 0.5 and std 0.5.
transformT = transforms.Compose(
    [
        RandomPadToSize(target_height=224, target_width=224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Custom Dataset
class HandwritingDataset(Dataset):
    """Image-folder dataset for the eight classes listed in `label_map`.

    `root_dir` is expected to contain one sub-directory per `label_map` key;
    class directories that do not exist are skipped. Files are indexed at
    construction time when their extension is in `VALID_EXTENSIONS`, they
    can be opened, and their shorter side is at least `MIN_SIDE` pixels.

    Each item is `(image, label)`: the image converted to RGB (and passed
    through `transform` when one was given) and the integer class index.

    Args:
        root_dir: Directory holding the per-class sub-directories.
        transform: Optional callable applied to the loaded RGB image.
    """

    # Images whose shorter side is below this many pixels are ignored.
    MIN_SIDE = 30
    # File extensions (lower-case, with dot) that are considered images.
    VALID_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"})

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        self.label_map = {
            "CLEAN": 0,
            "CROSS": 1,
            "DIAGONAL": 2,
            "DOUBLE_LINE": 3,
            "SCRATCH": 4,
            "SINGLE_LINE": 5,
            "WAVE": 6,
            "ZIG_ZAG": 7
        }
        for label, class_index in self.label_map.items():
            class_dir = os.path.join(root_dir, label)
            if not os.path.isdir(class_dir):
                continue
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # sample indices stable between runs and machines.
            for fname in sorted(os.listdir(class_dir)):
                if os.path.splitext(fname)[1].lower() not in self.VALID_EXTENSIONS:
                    continue
                fpath = os.path.join(class_dir, fname)
                try:
                    with Image.open(fpath) as img:
                        large_enough = min(img.size) >= self.MIN_SIDE
                except Exception:
                    # Deliberately best-effort: one broken file must not
                    # abort indexing of the whole directory tree.
                    print(f"Warning: Skipping unreadable file during init: {fpath}")
                    continue
                if large_enough:
                    self.samples.append((fpath, class_index))

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.samples)

    def _load_rgb(self, img_path):
        """Return the image at `img_path` as RGB, or None if it is unusable.

        A warning is printed for unreadable or too-small files.
        """
        try:
            # The context manager releases the file handle; the converted
            # image is what is used after the block has closed the source.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
            if min(image.size) < self.MIN_SIDE:
                raise ValueError("Image too small")
        except (UnidentifiedImageError, ValueError, OSError) as e:
            print(f"Warning: Skipping file {img_path} ({str(e)})")
            return None
        return image

    def __getitem__(self, idx):
        """Return `(image, label)` for `idx`, substituting the next usable sample.

        If the file at `idx` cannot be used, the following samples are tried
        in order (wrapping around), so the returned label may belong to a
        different sample than the one requested.

        Raises:
            IndexError: `idx` is outside the range of indexed samples.
            RuntimeError: no sample in the dataset could be loaded.
        """
        total = len(self.samples)
        for attempt in range(max(total, 1)):
            # The first attempt indexes directly so an out-of-range idx raises
            # IndexError like any sequence; later attempts wrap around.
            position = idx if attempt == 0 else (idx + attempt) % total
            img_path, label = self.samples[position]
            image = self._load_rgb(img_path)
            if image is None:
                continue
            if self.transform:
                image = self.transform(image)
            return image, label
        # Every sample was tried once; stop instead of retrying forever.
        raise RuntimeError(
            f"No readable image found among {total} samples in {self.root_dir!r}"
        )

# Load datasets
# The dataset root can be overridden through the CROSS_OUT_DATASET_ROOT
# environment variable; the default is the original local path.
DATA_ROOT = os.environ.get("CROSS_OUT_DATASET_ROOT", r"C:\Skola\D7047e\cross_out_dataset")
train_dataset = HandwritingDataset(os.path.join(DATA_ROOT, "train", "images"), transform=transformT)
val_dataset = HandwritingDataset(os.path.join(DATA_ROOT, "val", "images"), transform=transformT)

# HandwritingDataset silently skips missing class directories, so a wrong
# path yields an empty dataset. Fail here with a readable message instead of
# later, where the per-epoch metrics divide by the dataset length.
for _split_name, _split in (("train", train_dataset), ("val", val_dataset)):
    if len(_split) == 0:
        raise FileNotFoundError(
            f"No usable images found for the {_split_name} split under {_split.root_dir!r}"
        )

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Define model
# The output layer is sized from the dataset's label map (8 classes) so the
# two cannot drift apart.
NUM_CLASSES = len(train_dataset.label_map)
model = resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
model = model.to(DEVICE)


In [4]:
# Loss and optimizer
# Adam updates every model parameter at the configured learning rate;
# cross-entropy is the training and validation objective.
optimizer = optim.Adam(params=model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

# Training loop with validation
def _run_epoch(loader, training):
    """Make one pass over `loader` and return `(mean loss, accuracy, weighted F1)`.

    Works on the notebook-level `model`, `criterion`, `optimizer` and
    `DEVICE`. When `training` is true, every batch is followed by a
    zero_grad / backward / step update. Choosing train or eval mode and any
    no-grad context is left to the caller.
    """
    loss_sum = 0.0
    predicted, expected = [], []

    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)

        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)

        if training:
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

        # Weight the batch loss by the batch size so that dividing by the
        # dataset length below gives a per-sample average (assumes the
        # criterion reports a per-batch mean).
        loss_sum += batch_loss.item() * batch_images.size(0)
        # torch.max over dim 1 returns (values, indices); keep the indices.
        predicted.extend(torch.max(logits, 1)[1].cpu().numpy())
        expected.extend(batch_labels.cpu().numpy())

    mean_loss = loss_sum / len(loader.dataset)
    accuracy = accuracy_score(expected, predicted)
    weighted_f1 = f1_score(expected, predicted, average='weighted')
    return mean_loss, accuracy, weighted_f1


def train(num_epochs):
    """Train the notebook-level `model` for `num_epochs` epochs.

    Each epoch runs one optimisation pass over `train_loader` and one
    gradient-free evaluation pass over `val_loader`, then logs loss,
    accuracy and weighted F1 for both splits to wandb and prints a
    one-line summary. Returns nothing.
    """
    for epoch_index in range(num_epochs):
        epoch_number = epoch_index + 1

        # Training
        model.train()
        epoch_loss, epoch_acc, epoch_f1 = _run_epoch(train_loader, training=True)

        # Validation
        model.eval()
        with torch.no_grad():
            val_loss, val_acc, val_f1 = _run_epoch(val_loader, training=False)

        metrics = {
            "Train Loss": epoch_loss,
            "Train Accuracy": epoch_acc,
            "Train F1": epoch_f1,
            "Validation Loss": val_loss,
            "Validation Accuracy": val_acc,
            "Validation F1": val_f1,
            "epoch": epoch_number,
        }
        wandb.log(metrics)

        print(f"Epoch [{epoch_number}/{num_epochs}]",
              f"Train Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}, F1: {epoch_f1:.4f} |",
              f"Val Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}, F1: {val_f1:.4f}")

# Run training for three epochs
train(num_epochs=3)

# Save model: only the state_dict (the weights) is written, to a path
# relative to the current working directory.
torch.save(obj=model.state_dict(), f="resnet18_multiclass.pth")



Epoch [1/3] Train Loss: 0.2473, Accuracy: 0.9010, F1: 0.9006 | Val Loss: 2.9068, Accuracy: 0.5916, F1: 0.5337
Epoch [2/3] Train Loss: 0.0949, Accuracy: 0.9643, F1: 0.9643 | Val Loss: 0.6298, Accuracy: 0.8515, F1: 0.8615
Epoch [3/3] Train Loss: 0.0637, Accuracy: 0.9766, F1: 0.9766 | Val Loss: 0.0503, Accuracy: 0.9812, F1: 0.9812
