In [None]:
from tqdm import tqdm
# Imports and constants
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
import os
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import numpy as np
from torchvision.models import resnet18, ResNet18_Weights
from torch.utils.data import Subset
import random


# Number of samples per mini-batch.
BATCH_SIZE = 64
# Step size handed to the Adam optimizer in the script entry point.
LEARNING_RATE = 0.0001
# Number of training epochs.
# NOTE(review): the script entry point passes its own epoch count to
# train_model -- confirm which value is intended.
EPOCHS = 15

class CrossoutDataset(Dataset):
    """Image-folder dataset that yields two labels per image.

    The directory under ``folder_root`` is read with ``ImageFolder`` and must
    contain at least a ``CLEAN`` and a ``MIXED`` class folder (a ``KeyError``
    is raised otherwise). Every other class folder is treated as a cross-out
    type and numbered in ``ImageFolder`` class order.

    Each item is ``(image, is_crossed, type_label)``:

    * ``is_crossed`` is 0 for ``CLEAN`` images and 1 for everything else,
      ``MIXED`` included.
    * ``type_label`` is the index of the folder within ``cross_types``.
      Folders that are not cross types (``CLEAN`` and ``MIXED``) fall back to
      0, i.e. they share the label of the first cross type.
    """

    def __init__(self, folder_root, transform=None):
        self.base = ImageFolder(root=folder_root, transform=transform)

        name_to_class = self.base.class_to_idx
        self.clean_idx = name_to_class['CLEAN']
        self.mixed_idx = name_to_class['MIXED']

        # Everything that is neither CLEAN nor MIXED is a cross-out type.
        excluded = ('CLEAN', 'MIXED')
        self.cross_types = [
            name for name in self.base.classes if name not in excluded
        ]
        self.type2idx = dict(zip(self.cross_types, range(len(self.cross_types))))

    def __len__(self):
        """Return the number of images found by the underlying ImageFolder."""
        return len(self.base)

    def __getitem__(self, idx):
        """Return ``(image, is_crossed, type_label)`` for sample ``idx``."""
        img, class_idx = self.base[idx]
        folder_name = self.base.classes[class_idx]

        is_crossed = int(class_idx != self.clean_idx)
        # Non-type folders (CLEAN, MIXED) are not in the mapping and get 0.
        if folder_name in self.type2idx:
            type_label = self.type2idx[folder_name]
        else:
            type_label = 0
        return img, is_crossed, type_label

# Data loading function
def get_data_loaders(batch_size, data_root, sample_fraction=0.2, num_workers=4):
    """Build the train, validation and test ``DataLoader`` objects.

    Images are read from ``<data_root>/train/images``, ``<data_root>/val/images``
    and ``<data_root>/test/images`` through ``CrossoutDataset``.

    Args:
        batch_size: Batch size used by all three loaders.
        data_root: Directory containing the ``train``, ``val`` and ``test`` splits.
        sample_fraction: When below 1.0, the training set is replaced by a
            random subset (drawn without replacement) of
            ``int(len(train) * sample_fraction)`` images.
        num_workers: Worker process count passed to every loader.

    Returns:
        ``(trainloader, validationloader, testloader)``. Only the training
        loader shuffles.
    """
    target_size = (224, 224)  # For ResNet or larger models

    # Training images get light augmentation; val/test are only resized.
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.Resize(target_size),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
    ])

    train_dir, val_dir, test_dir = (
        os.path.join(data_root, split_dir)
        for split_dir in ("train/images", "val/images", "test/images")
    )

    train_dataset = CrossoutDataset(train_dir, transform=train_transform)
    val_dataset = CrossoutDataset(val_dir, transform=test_transform)
    test_dataset = CrossoutDataset(test_dir, transform=test_transform)

    # Optionally shrink the training set for faster runs.
    if sample_fraction < 1.0:
        full_size = len(train_dataset)
        kept = random.sample(range(full_size), int(full_size * sample_fraction))
        train_dataset = Subset(train_dataset, kept)

    shared_options = dict(batch_size=batch_size, num_workers=num_workers)
    trainloader = DataLoader(train_dataset, shuffle=True, **shared_options)
    validationloader = DataLoader(val_dataset, shuffle=False, **shared_options)
    testloader = DataLoader(test_dataset, shuffle=False, **shared_options)

    return trainloader, validationloader, testloader

# Model definition
from torchvision.models import resnet18
import torch.nn.functional as F

class CrossoutResNet(nn.Module):
    """ResNet18 feature extractor with two classification heads.

    The pretrained ResNet18 (default weights) is used without its final
    fully-connected layer. All backbone parameters are frozen except those of
    ``layer4``. A shared 512 -> 128 linear layer with ReLU feeds:

    * ``fc_bin``  -- 2 logits (CLEAN vs CROSSED)
    * ``fc_type`` -- ``num_types`` logits (cross-out type)

    Args:
        num_types: Number of cross-out type classes for the type head.
            Defaults to 7, the value that was previously hard-coded; it must
            match the number of type labels produced by the dataset.
    """

    def __init__(self, num_types=7):
        super().__init__()
        resnet = resnet18(weights=ResNet18_Weights.DEFAULT)

        # Freeze every pretrained layer, then re-enable gradients for the last
        # residual stage only so that it can be fine-tuned.
        for param in resnet.parameters():
            param.requires_grad = False
        for param in resnet.layer4.parameters():
            param.requires_grad = True

        # Drop the original classifier (fc layer); output is [B, 512, 1, 1].
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        # Custom classification heads on top of a shared hidden layer.
        self.fc1 = nn.Linear(512, 128)
        self.fc_bin = nn.Linear(128, 2)            # Binary: CLEAN vs CROSSED
        self.fc_type = nn.Linear(128, num_types)   # Multi-class: cross types

    def forward(self, x):
        """Return ``(binary_logits, type_logits)`` for a batch of images."""
        features = self.backbone(x)
        features = torch.flatten(features, 1)      # Flatten to [B, 512]
        hidden = F.relu(self.fc1(features))
        return self.fc_bin(hidden), self.fc_type(hidden)

# Training function
def _multitask_loss(out_bin, out_type, is_crossed, type_label,
                    criterion_bin, criterion_class, lambda_type):
    """Return ``(loss, mask)`` for one batch of the two-head model.

    ``mask`` selects the rows with ``is_crossed == 1``. The type loss is
    computed on those rows only and contributes zero when there are none.
    """
    mask = is_crossed == 1
    loss_bin = criterion_bin(out_bin, is_crossed)
    if mask.any():
        loss_type = criterion_class(out_type[mask], type_label[mask])
    else:
        # Zero scalar on the same device/dtype as the model output.
        loss_type = out_bin.new_zeros(())
    return loss_bin + lambda_type * loss_type, mask


def _cross_type_names(loader):
    """Return the ``cross_types`` of the loader's dataset.

    Wrappers that expose the wrapped dataset as ``.dataset`` (for example
    ``torch.utils.data.Subset``) are looked through. ``AttributeError`` is
    raised when no dataset in the chain defines ``cross_types``.
    """
    dataset = loader.dataset
    while not hasattr(dataset, "cross_types") and hasattr(dataset, "dataset"):
        dataset = dataset.dataset
    return list(dataset.cross_types)


def train_model(model, criterion_bin, criterion_class, optimizer,
                trainloader, validationloader, num_epochs, device, lambda_type=1.0):
    """Train the two-head model and validate after every epoch.

    Args:
        model: Module returning ``(binary_logits, type_logits)``.
        criterion_bin: Loss applied to the binary head for every sample.
        criterion_class: Loss applied to the type head for samples with
            ``is_crossed == 1``.
        optimizer: Optimizer stepping the model parameters.
        trainloader: Iterable of ``(inputs, is_crossed, type_label)`` batches.
        validationloader: Same batch format; its dataset must expose
            ``cross_types`` (directly or through ``.dataset`` wrappers).
        num_epochs: Number of passes over ``trainloader``.
        device: Device the batches are moved to.
        lambda_type: Weight of the type loss in the total loss.

    Returns:
        The same ``model`` instance, left in training mode.

    The reported loss values are sums of the per-batch losses over the epoch.
    """
    from collections import Counter

    for epoch in range(num_epochs):
        # Validation switches the model to eval mode, so training mode has to
        # be restored at the start of every epoch, not just the first one.
        model.train()
        total_loss, correct_bin, correct_type, total, valid_type_total = 0.0, 0, 0, 0, 0

        for inputs, is_crossed, type_label in tqdm(trainloader, desc=f"Training Epoch {epoch+1}"):
            inputs, is_crossed, type_label = inputs.to(device), is_crossed.to(device), type_label.to(device)
            optimizer.zero_grad()
            out_bin, out_type = model(inputs)
            loss, mask = _multitask_loss(out_bin, out_type, is_crossed, type_label,
                                         criterion_bin, criterion_class, lambda_type)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total += is_crossed.size(0)
            correct_bin += (out_bin.argmax(1) == is_crossed).sum().item()
            if mask.any():
                correct_type += (out_type[mask].argmax(1) == type_label[mask]).sum().item()
                valid_type_total += mask.sum().item()

        bin_acc = correct_bin / total if total else 0.0
        type_acc = correct_type / valid_type_total if valid_type_total else 0.0
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss:.4f}, Binary Acc: {bin_acc:.4f}, Type Acc: {type_acc:.4f}")

        # Validation phase
        model.eval()
        type_names = _cross_type_names(validationloader)
        val_loss, val_correct_bin, val_correct_type, val_total, val_type_total = 0.0, 0, 0, 0, 0
        # Counters return 0 for unseen keys, so an unexpected label index
        # cannot raise KeyError.
        per_type_total = Counter()
        per_type_correct = Counter()

        with torch.no_grad():
            for inputs, is_crossed, type_label in tqdm(validationloader, desc=f"Validation Epoch {epoch+1}"):
                inputs, is_crossed, type_label = inputs.to(device), is_crossed.to(device), type_label.to(device)
                out_bin, out_type = model(inputs)
                loss, mask = _multitask_loss(out_bin, out_type, is_crossed, type_label,
                                             criterion_bin, criterion_class, lambda_type)
                val_loss += loss.item()
                val_total += is_crossed.size(0)
                val_correct_bin += (out_bin.argmax(1) == is_crossed).sum().item()
                if mask.any():
                    # One transfer per tensor instead of one .item() per sample.
                    preds = out_type[mask].argmax(1).tolist()
                    labels = type_label[mask].tolist()
                    hits = [t for t, p in zip(labels, preds) if t == p]
                    val_correct_type += len(hits)
                    val_type_total += len(labels)
                    per_type_total.update(labels)
                    per_type_correct.update(hits)

        val_bin_acc = val_correct_bin / val_total if val_total else 0.0
        val_type_acc = val_correct_type / val_type_total if val_type_total else 0.0
        print(f"Validation - Loss: {val_loss:.4f}, Binary Acc: {val_bin_acc:.4f}, Type Acc: {val_type_acc:.4f}")
        print("Validation - Per-type Accuracy:")
        for i, type_name in enumerate(type_names):
            acc = per_type_correct[i] / per_type_total[i] if per_type_total[i] > 0 else 0
            print(f"  {type_name}: {acc:.4f}")

    # Hand the model back in training mode, as callers of this function have
    # always received it; evaluation code must call model.eval() itself.
    model.train()

    return model

def _show_confusion_matrix(matrix, display_labels, title, figsize):
    """Plot one confusion matrix with integer cell counts and show it."""
    disp = ConfusionMatrixDisplay(confusion_matrix=matrix, display_labels=display_labels)
    _, ax = plt.subplots(figsize=figsize)
    # 'd' keeps counts as plain integers; '.2g' switches to scientific
    # notation once a cell reaches three digits.
    disp.plot(ax=ax, cmap='Blues', values_format='d')
    plt.title(title)
    plt.show()


def evaluate(model, testloader, device=None):
    """Report test accuracy and plot confusion matrices for both heads.

    Args:
        model: Module returning ``(binary_logits, type_logits)``.
        testloader: Iterable of ``(inputs, is_crossed, type_label)`` batches.
        device: Device the batches are moved to. When omitted, the device of
            the model's first parameter is used (CPU if the model has none).

    Prints the binary accuracy over all samples and the type accuracy over
    samples with ``is_crossed == 1``, then shows one confusion matrix per
    head. Type-matrix axis labels are taken from the ``cross_types`` of the
    loader's dataset when it exposes them, so that label index ``i`` is named
    ``cross_types[i]``; otherwise ``"type <i>"`` is used. Nothing is plotted
    for a head without samples.
    """
    model.eval()
    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    y_true_binary = []
    y_pred_binary = []
    y_true_type = []
    y_pred_type = []

    with torch.no_grad():
        for inputs, is_crossed, type_label in testloader:
            inputs, is_crossed, type_label = inputs.to(device), is_crossed.to(device), type_label.to(device)
            out_bin, out_type = model(inputs)

            # Save binary predictions
            y_true_binary.extend(is_crossed.tolist())
            y_pred_binary.extend(out_bin.argmax(1).tolist())

            # Only evaluate type classification for crossed images
            mask = (is_crossed == 1)
            if mask.any():
                y_true_type.extend(type_label[mask].tolist())
                y_pred_type.extend(out_type[mask].argmax(1).tolist())

    total = len(y_true_binary)
    valid_type_total = len(y_true_type)
    correct_bin = sum(t == p for t, p in zip(y_true_binary, y_pred_binary))
    correct_type = sum(t == p for t, p in zip(y_true_type, y_pred_type))

    bin_acc = correct_bin / total if total else 0.0
    type_acc = correct_type / valid_type_total if valid_type_total else 0.0

    print(f"Test Binary Accuracy (CLEAN vs CROSSED): {bin_acc:.4f}")
    print(f"Test Type Accuracy (on crossed only): {type_acc:.4f}")

    # Plot Binary Confusion Matrix (CLEAN vs CROSSED)
    if y_true_binary:
        # Fixing labels keeps the matrix 2x2 even when only one class occurs,
        # so it always matches the two display labels.
        cm_bin = confusion_matrix(y_true_binary, y_pred_binary, labels=[0, 1])
        _show_confusion_matrix(
            cm_bin, ['CLEAN', 'CROSSED'],
            "Confusion Matrix - Binary (CLEAN vs CROSSED)", (6, 6),
        )

    # Plot Type Confusion Matrix (only crossed)
    if y_true_type and y_pred_type:
        # Resolve type names from the dataset (looking through Subset-style
        # wrappers) instead of a hard-coded list that can drift out of sync
        # with the label indices.
        dataset = getattr(testloader, "dataset", None)
        while (dataset is not None and not hasattr(dataset, "cross_types")
               and hasattr(dataset, "dataset")):
            dataset = dataset.dataset
        type_names = list(getattr(dataset, "cross_types", []))

        classes_in_data = np.unique(y_true_type + y_pred_type)
        labels_used = [
            type_names[i] if i < len(type_names) else f"type {i}"
            for i in classes_in_data
        ]

        cm_type = confusion_matrix(y_true_type, y_pred_type, labels=classes_in_data)
        _show_confusion_matrix(
            cm_type, labels_used,
            "Confusion Matrix - Crossed Types Only", (10, 8),
        )

# Script entry point: build the loaders, fine-tune the model, then evaluate it
# on the test split.
if __name__ == '__main__':
    # Prefer the GPU when PyTorch can see one.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Machine-specific dataset location; adjust when running elsewhere.
    data_root = r"D:\PLUGGET\D7047E\Project\Project\data\cross_out_dataset"
    trainloader, validationloader, testloader = get_data_loaders(
        batch_size=BATCH_SIZE,  # module constant instead of a duplicate literal
        data_root=data_root,
        sample_fraction=0.2,
        num_workers=4
    )

    model = CrossoutResNet().to(device)
    criterion_bin = nn.CrossEntropyLoss()
    criterion_class = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # NOTE(review): the module-level EPOCHS constant is 15, but this run
    # trains for 10 epochs -- confirm which value is intended.
    trained_model = train_model(
        model,
        criterion_bin,
        criterion_class,
        optimizer,
        trainloader,
        validationloader,
        num_epochs=10,
        device=device,
        lambda_type=1.0
    )

    evaluate(trained_model, testloader)

This was run on a local machine, so the results are not included here.