# 1. Import dependencies

In [1]:
import os
import random
import yaml
import csv
import math

import torch
import torch.nn as nn
import torchvision
import numpy as np
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
import wandb
import shutil
import gc
import timm

from functools import partial
from torchsummaryX import summary
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torchvision.transforms import functional as F
import torchvision.transforms.functional as TF
from tqdm import tqdm
from tqdm import tqdm

# 2. Configurations

In [3]:
config = dict(
    model="seresnext",

    ###### Dataset -----------------------------------------------------------------
    dataset_root="./data/VOC_PCB",
    num_workers=12,
    batch_size=256,

    ###### DataAugment ---------------------------------------------------------------
    # (no augmentation options yet)

    ###### Loss function -------------------------------------------------------------
    loss="crossEntropy",
    ce_smoothing_factor=0.2,

    ###### Scheduler Parameters ------------------------------------------------------
    scheduler="ReduceLR",  # only 'ReduceLR' is implemented below ('CosineAnnealing' is not)
    learning_rate=1e-3,
    reducelr_factor=0.5,
    reducelr_patience=3,
    reducelr_threshold=1e-3,
    reducelr_minlr=1e-6,

    ###### Optimizer Parameters ------------------------------------------------------
    optimizer="AdamW",  # only 'AdamW' is implemented below
    weight_decay=0.01,

    ###### Training Parameters -------------------------------------------------------
    use_wandb=True,
    dropout_rate=0.2,
    epochs=100,
)

In [4]:
config_path = "./config.yaml"
# Persist the run configuration (keys kept in insertion order, block style) for later copying/logging
with open(config_path, mode="w") as config_file:
    yaml.dump(config, stream=config_file, default_flow_style=False, sort_keys=False)

In [5]:
# Defect name -> zero-based integer class label (order defines the index)
DEFECT_CLASSES = {
    defect_name: class_index
    for class_index, defect_name in enumerate(
        ("mouse_bite", "short", "open_circuit", "spur", "missing_hole", "spurious_copper")
    )
}

In [6]:
# Preferred CUDA device index. Falls back to cuda:0 when this machine has fewer GPUs,
# instead of failing later when the model/tensors are moved to a non-existent device.
GPU_INDEX = 7
if torch.cuda.is_available():
    device = torch.device(f"cuda:{GPU_INDEX}" if GPU_INDEX < torch.cuda.device_count() else "cuda:0")
else:
    device = torch.device("cpu")

# 3. Data Loading

In [None]:
def count_images_for_splits(root, split_files=('train', 'trainval', 'val', 'test')):
    """Count the image ids listed in each VOC ``ImageSets/Main/<split>.txt`` file.

    Args:
        root: dataset root directory containing ``ImageSets/Main``.
        split_files: split names to look up, in the order they should appear in
            the result (defaults to train / trainval / val / test).

    Returns:
        dict mapping split name -> number of non-blank id lines; 0 when the
        split file does not exist.
    """
    image_counts = {}
    for split in split_files:
        file_path = os.path.join(root, "ImageSets", "Main", f"{split}.txt")
        if not os.path.exists(file_path):
            image_counts[split] = 0  # missing split file counts as empty
            continue
        with open(file_path, 'r') as f:
            # Blank lines are not image ids (previously they were counted)
            image_counts[split] = sum(1 for line in f if line.strip())
    return image_counts

image_counts = count_images_for_splits(config["dataset_root"])
# Report how many ids each ImageSets/Main split file lists
for split_name, n_ids in image_counts.items():
    print(f"Number of images in {split_name}: {n_ids}")

## 3.1 Dataset

In [8]:
class PCBDataset(Dataset):
    """Image-level classification view of a VOC-style PCB defect dataset.

    Expects ``<root>/JPEGImages/<id>.jpg``, ``<root>/Annotations/<id>.xml`` and id
    lists in ``<root>/ImageSets/Main/<split>.txt``. Each image gets one integer
    label: the most frequent ``<object><name>`` in its annotation, mapped through
    ``DEFECT_CLASSES``.

    Args:
        root: dataset root directory.
        transforms: optional callable applied to the RGB PIL image.
        split: ``"train"`` merges ``train.txt`` and ``trainval.txt`` minus any id
            listed in ``val.txt`` / ``test.txt``; any other value reads ``<split>.txt``.
    """

    def __init__(self, root, transforms=None, split="train"):
        self.root = root
        self.transforms = transforms
        self.split = split

        self.img_dir = os.path.join(root, "JPEGImages")
        self.ann_dir = os.path.join(root, "Annotations")

        if split == "train":
            merged_ids = self._read_ids("train") + self._read_ids("trainval")
            # In the standard VOC layout trainval.txt = train ∪ val (NOTE(review): confirm for
            # this dataset), so merging it in leaked validation images into training. Drop every
            # id an evaluation split lists; if the splits are already disjoint this is a no-op.
            held_out = set()
            for eval_split in ("val", "test"):
                if os.path.exists(self._split_path(eval_split)):
                    held_out.update(self._read_ids(eval_split))
            # dict.fromkeys de-duplicates while keeping file order; set() order would
            # depend on PYTHONHASHSEED and differ between runs.
            self.img_ids = [img_id for img_id in dict.fromkeys(merged_ids) if img_id not in held_out]
        else:
            self.img_ids = self._read_ids(split)

        # Defect name -> zero-based class index
        self.defect_classes = DEFECT_CLASSES

    def _split_path(self, split):
        """Path of the ``ImageSets/Main/<split>.txt`` id list."""
        return os.path.join(self.root, "ImageSets", "Main", f"{split}.txt")

    def _read_ids(self, split):
        """Return the non-blank, whitespace-stripped ids of a split file (raises if missing)."""
        with open(self._split_path(split)) as f:
            return [line.strip() for line in f if line.strip()]

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the idx-th id; ``image`` is transformed if a transform was given."""
        img_id = self.img_ids[idx]
        img_path = os.path.join(self.img_dir, f"{img_id}.jpg")
        ann_path = os.path.join(self.ann_dir, f"{img_id}.xml")

        # Context manager closes the underlying file; convert() returns a new image
        with Image.open(img_path) as pil_img:
            img = pil_img.convert("RGB")

        # Parse XML annotation file to get defect labels
        root = ET.parse(ann_path).getroot()

        labels = []
        for obj in root.findall("object"):
            defect_name = obj.find("name").text
            if defect_name not in self.defect_classes:
                raise RuntimeError(f"Unexpected defect type {defect_name}")
            labels.append(self.defect_classes[defect_name])

        if labels:
            # Most common defect in the image; ties are broken by set iteration order
            label = max(set(labels), key=labels.count)
        else:
            # NOTE(review): images without objects fall back to class 0 ("mouse_bite");
            # there is no explicit "no defect" class — confirm this is intended.
            label = 0

        if self.transforms:
            img = self.transforms(img)

        return img, label

## 3.2 Data Augmentation

In [9]:
# Preprocessing pipeline (no augmentation): tensor conversion, resize, normalization
transform = transforms.Compose(transforms=[
    # PIL image -> float tensor in [0, 1], channels first
    transforms.ToTensor(),
    # The backbones are fed 224x224 inputs (the source images are 600x600)
    transforms.Resize(size=(224, 224)),
    # Standard ImageNet per-channel mean/std
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


## 3.3 Data Loader

In [10]:
# Load datasets for train, val, and test splits
train_dataset = PCBDataset(root=config["dataset_root"], transforms=transform, split="train")
val_dataset = PCBDataset(root=config["dataset_root"], transforms=transform, split="val")
test_dataset = PCBDataset(root=config["dataset_root"], transforms=transform, split="test")

# Shared DataLoader settings. config["num_workers"] was previously ignored, so all decoding and
# resizing ran in the main process. Page-locked memory speeds host->GPU copies on CUDA.
# NOTE: workers must be able to pickle PCBDataset; with the "spawn" start method (Windows/macOS),
# classes defined in a notebook can fail to pickle — set num_workers to 0 there.
loader_kwargs = dict(
    batch_size=config["batch_size"],
    num_workers=config.get("num_workers", 0),
    pin_memory=(device.type == "cuda"),
)

# Define DataLoaders without custom collate_fn
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Sizes of the constructed datasets (train = merged/filtered ids, see PCBDataset)
for _split_name, _split_ds in (("train", train_dataset), ("validation", val_dataset), ("test", test_dataset)):
    print(f"Number of images in {_split_name} dataset: {len(_split_ds)}")

## 3.4 Data preview

In [None]:
# Smoke-test the pipeline on the first ten training items: fetch one sample and show its label
subset_dataset = torch.utils.data.Subset(train_dataset, indices=range(10))
subset_loader = DataLoader(subset_dataset, batch_size=1)

first_sample = next(iter(subset_loader), None)
if first_sample is not None:
    img, label = first_sample
    print("Sample loaded successfully.")
    print("Label:", label.item())

In [None]:
# Human-readable names for the zero-based class indices produced by PCBDataset
label_map = {
    0: "Mouse Bites",
    1: "Shorts",
    2: "Open Circuits",
    3: "Spurs",
    4: "Missing Holes",
    5: "Spurious Coppers"
}

# Set grid dimensions
rows, cols = 5, 5  # Display a 5x5 grid of images

# Mean/std used by `transform`'s Normalize; needed to undo it for display.
# Keep in sync with the transform cell.
display_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
display_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# Initialize plot
fig, ax = plt.subplots(rows, cols, figsize=(15, 15))

# Load a batch of samples (images and labels)
data_loader = DataLoader(train_dataset, batch_size=rows * cols, shuffle=True)
images, labels = next(iter(data_loader))

for idx in range(rows * cols):
    i, j = divmod(idx, cols)  # grid position
    ax[i, j].axis('off')
    if idx >= len(images):
        continue  # dataset smaller than the grid: leave remaining cells empty

    image = images[idx]
    label = labels[idx].item()

    # Undo Normalize and clamp to [0, 1]: to_pil_image scales float tensors by 255
    # without clipping, so normalized (negative / >1) values rendered as garbage colours.
    img = TF.to_pil_image((image * display_std + display_mean).clamp(0, 1))
    ax[i, j].imshow(img)

    # Display the human-readable label above each image
    ax[i, j].set_title(label_map.get(label, "Unknown"), fontsize=10, color="blue")

plt.tight_layout()
plt.show()

# 4. Model

In [14]:
class ResNetClassifier(nn.Module):
    """timm ``resnet50`` with pretrained weights and a ``num_classes``-way classification head."""

    def __init__(self, num_classes):
        super(ResNetClassifier, self).__init__()
        self.backbone = timm.create_model(
            model_name="resnet50",
            pretrained=True,
            num_classes=num_classes,
        )

    def forward(self, x):
        """Return the backbone's class scores for the image batch ``x``."""
        scores = self.backbone(x)
        return scores

In [15]:
class EfficientNetClassifier(nn.Module):
    """timm ``efficientnet_b0`` with pretrained weights and a ``num_classes``-way classification head."""

    def __init__(self, num_classes):
        super(EfficientNetClassifier, self).__init__()
        self.backbone = timm.create_model(
            model_name="efficientnet_b0",
            pretrained=True,
            num_classes=num_classes,
        )

    def forward(self, x):
        """Return the backbone's class scores for the image batch ``x``."""
        scores = self.backbone(x)
        return scores

In [16]:
class ViTClassifier(nn.Module):
    """timm ``vit_base_patch16_224`` with pretrained weights and a ``num_classes``-way classification head."""

    def __init__(self, num_classes):
        super(ViTClassifier, self).__init__()
        self.backbone = timm.create_model(
            model_name="vit_base_patch16_224",
            pretrained=True,
            num_classes=num_classes,
        )

    def forward(self, x):
        """Return the backbone's class scores for the image batch ``x``."""
        scores = self.backbone(x)
        return scores

In [17]:
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, LayerType, create_attn, \
    get_attn, get_act_layer, get_norm_layer, create_classifier, create_aa, to_ntuple

In [18]:
class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 conv -> 3x3 grouped conv -> 1x1 conv, plus shortcut.

    Residual branch: conv1/bn1/act1 -> conv2/bn2/drop_block/act2/aa -> conv3/bn3
    -> optional ``se`` attention -> optional ``drop_path``. The result is added to
    the (optionally ``downsample``-d) input and passed through ``act3``.
    Output channels are ``planes * expansion``.
    """
    # Output-channel multiplier of conv3 relative to ``planes``
    expansion = 4

    def __init__(
            self,
            inplanes: int,
            planes: int,
            stride: int = 1,
            downsample: Optional[nn.Module] = None,
            cardinality: int = 1,
            base_width: int = 64,
            reduce_first: int = 1,
            dilation: int = 1,
            first_dilation: Optional[int] = None,
            act_layer: Type[nn.Module] = nn.ReLU,
            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
            attn_layer: Optional[Type[nn.Module]] = None,
            aa_layer: Optional[Type[nn.Module]] = None,
            drop_block: Optional[Type[nn.Module]] = None,
            drop_path: Optional[nn.Module] = None,
    ):
        """
        Args:
            inplanes: Input channel dimensionality.
            planes: Used to determine output channel dimensionalities.
            stride: Stride used in convolution layers.
            downsample: Optional downsample layer for residual path.
            cardinality: Number of convolution groups.
            base_width: Base width used to determine output channel dimensionality.
            reduce_first: Reduction factor for first convolution output width of residual blocks.
            dilation: Dilation rate for convolution layers.
            first_dilation: Dilation rate for first convolution layer.
            act_layer: Activation layer.
            norm_layer: Normalization layer.
            attn_layer: Attention layer.
            aa_layer: Anti-aliasing layer.
            drop_block: Class for DropBlock layer.
            drop_path: Optional DropPath layer.
        """
        super(Bottleneck, self).__init__()

        # Grouped-conv width: planes scaled by base_width / 64, multiplied by the number of groups
        width = int(math.floor(planes * (base_width / 64)) * cardinality)
        first_planes = width // reduce_first
        outplanes = planes * self.expansion
        first_dilation = first_dilation or dilation
        # Anti-aliasing is only enabled when this block strides by 2 or changes dilation
        use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)

        self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
        self.bn1 = norm_layer(first_planes)
        self.act1 = act_layer(inplace=True)

        # With anti-aliasing on, conv2 keeps stride 1 and `stride` is handed to the aa layer instead
        self.conv2 = nn.Conv2d(
            first_planes, width, kernel_size=3, stride=1 if use_aa else stride,
            padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
        self.bn2 = norm_layer(width)
        self.drop_block = drop_block() if drop_block is not None else nn.Identity()
        self.act2 = act_layer(inplace=True)
        self.aa = create_aa(aa_layer, channels=width, stride=stride, enable=use_aa)

        self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
        self.bn3 = norm_layer(outplanes)

        # Optional attention on the expanded features (e.g. 'se'); forward() skips it when None
        self.se = create_attn(attn_layer, outplanes)

        self.act3 = act_layer(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.drop_path = drop_path

    def zero_init_last(self):
        """Zero bn3's scale so bn3 initially outputs only its bias, damping the residual branch."""
        if getattr(self.bn3, 'weight', None) is not None:
            nn.init.zeros_(self.bn3.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the residual branch to ``x``, add the shortcut, then activate."""
        shortcut = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.drop_block(x)
        x = self.act2(x)
        x = self.aa(x)

        x = self.conv3(x)
        x = self.bn3(x)

        if self.se is not None:
            x = self.se(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        if self.downsample is not None:
            # Project the shortcut so it matches the residual branch's channels/resolution
            shortcut = self.downsample(shortcut)
        x += shortcut
        x = self.act3(x)

        return x

In [19]:
def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int:
    """Symmetric padding for a convolution: ``((stride - 1) + dilation * (kernel_size - 1)) // 2``."""
    dilated_extent = dilation * (kernel_size - 1)
    return (dilated_extent + stride - 1) // 2

def downsample_conv(
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        first_dilation: Optional[int] = None,
        norm_layer: Optional[Type[nn.Module]] = None,
) -> nn.Module:
    """Shortcut projection: a (possibly strided/dilated) bias-free conv followed by a norm layer.

    The kernel collapses to 1x1 when the shortcut neither strides nor dilates.
    ``norm_layer`` defaults to ``nn.BatchNorm2d``.
    """
    norm_cls = norm_layer or nn.BatchNorm2d
    changes_geometry = stride != 1 or dilation != 1
    k = kernel_size if changes_geometry else 1
    conv_dilation = (first_dilation or dilation) if k > 1 else 1
    projection = nn.Conv2d(
        in_channels,
        out_channels,
        k,
        stride=stride,
        padding=get_padding(k, stride, conv_dilation),
        dilation=conv_dilation,
        bias=False,
    )
    return nn.Sequential(projection, norm_cls(out_channels))


def downsample_avg(
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        first_dilation: Optional[int] = None,
        norm_layer: Optional[Type[nn.Module]] = None,
) -> nn.Module:
    """Shortcut projection: 2x2 average pool (when needed) -> 1x1 conv -> norm layer.

    ``kernel_size`` and ``first_dilation`` are accepted only for signature parity
    with ``downsample_conv``; they are not used.
    """
    norm_cls = norm_layer or nn.BatchNorm2d
    if stride == 1 and dilation == 1:
        pool = nn.Identity()
    else:
        # When dilating instead of striding, pool with stride 1
        avg_stride = 1 if dilation != 1 else stride
        pool_kwargs = dict(ceil_mode=True, count_include_pad=False)
        if avg_stride == 1 and dilation > 1:
            pool = AvgPool2dSame(2, avg_stride, **pool_kwargs)
        else:
            pool = nn.AvgPool2d(2, avg_stride, **pool_kwargs)
    projection = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False)
    return nn.Sequential(pool, projection, norm_cls(out_channels))


def drop_blocks(drop_prob: float = 0.):
    """Per-stage DropBlock factories for a 4-stage network.

    Stages 1-2 never use DropBlock. Stages 3-4 get ``DropBlock2d`` partials
    (block_size 5 and 3) when ``drop_prob`` is non-zero; otherwise every entry is ``None``.
    """
    if not drop_prob:
        return [None, None, None, None]
    return [
        None,
        None,
        partial(DropBlock2d, drop_prob=drop_prob, block_size=5, gamma_scale=0.25),
        partial(DropBlock2d, drop_prob=drop_prob, block_size=3, gamma_scale=1.00),
    ]


def make_blocks(
        block_fns: Tuple[Type[nn.Module], ...],
        channels: Tuple[int, ...],
        block_repeats: Tuple[int, ...],
        inplanes: int,
        reduce_first: int = 1,
        output_stride: int = 32,
        down_kernel_size: int = 1,
        avg_down: bool = False,
        drop_block_rate: float = 0.,
        drop_path_rate: float = 0.,
        **kwargs,
) -> Tuple[List[Tuple[str, nn.Module]], List[Dict[str, Any]]]:
    """Build the residual stages (``layer1`` ... ``layerN``) of a ResNet-style network.

    Args:
        block_fns: one block class per stage; each must expose an ``expansion`` attribute.
        channels: base ``planes`` per stage (stage output channels are ``planes * expansion``).
        block_repeats: number of blocks per stage.
        inplanes: channels entering the first stage.
        reduce_first: forwarded to every block.
        output_stride: once the running stride reaches this value, later stages
            dilate instead of striding.
        down_kernel_size: kernel size of the shortcut projection.
        avg_down: use ``downsample_avg`` (pool + 1x1 conv) instead of ``downsample_conv``.
        drop_block_rate: DropBlock probability, applied to stages 3-4 via ``drop_blocks``.
        drop_path_rate: maximum stochastic-depth rate, ramped linearly across all blocks.
        **kwargs: extra block constructor arguments (``norm_layer`` is also used for shortcuts).

    Returns:
        ``(stages, feature_info)``: ``stages`` is a list of ``(name, nn.Sequential)``;
        ``feature_info`` holds one ``dict(num_chs, reduction, module)`` per stage.

    NOTE: ``zip`` stops after 4 stages because ``drop_blocks`` returns 4 entries.
    """
    stages = []
    feature_info = []
    net_num_blocks = sum(block_repeats)
    net_block_idx = 0
    # Assumes the stem already reduced resolution by 4 (as SEResNext's conv1 + maxpool do)
    net_stride = 4
    dilation = prev_dilation = 1
    for stage_idx, (block_fn, planes, num_blocks, db) in enumerate(zip(block_fns, channels, block_repeats, drop_blocks(drop_block_rate))):
        stage_name = f'layer{stage_idx + 1}'  # never liked this name, but weight compat requires it
        stride = 1 if stage_idx == 0 else 2
        if net_stride >= output_stride:
            # Target output stride reached: keep resolution and dilate instead
            dilation *= stride
            stride = 1
        else:
            net_stride *= stride

        downsample = None
        if stride != 1 or inplanes != planes * block_fn.expansion:
            down_kwargs = dict(
                in_channels=inplanes,
                out_channels=planes * block_fn.expansion,
                kernel_size=down_kernel_size,
                stride=stride,
                dilation=dilation,
                first_dilation=prev_dilation,
                norm_layer=kwargs.get('norm_layer'),
            )
            downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs)

        block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, drop_block=db, **kwargs)
        blocks = []
        for block_idx in range(num_blocks):
            # Only the first block of a stage strides / projects the shortcut
            downsample = downsample if block_idx == 0 else None
            stride = stride if block_idx == 0 else 1
            # Stochastic depth linear decay rule; max(..., 1) avoids dividing by zero
            # when the whole network has a single block.
            block_dpr = drop_path_rate * net_block_idx / max(net_num_blocks - 1, 1)
            blocks.append(block_fn(
                inplanes,
                planes,
                stride,
                downsample,
                first_dilation=prev_dilation,
                drop_path=DropPath(block_dpr) if block_dpr > 0. else None,
                **block_kwargs,
            ))
            prev_dilation = dilation
            inplanes = planes * block_fn.expansion
            net_block_idx += 1

        stages.append((stage_name, nn.Sequential(*blocks)))
        feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name))

    return stages, feature_info

In [20]:
class SEResNext(nn.Module):
    """ResNet-family classifier: stem -> residual stages (``make_blocks``) -> pooled linear head.

    With ``block=Bottleneck``, ``cardinality``/``base_width`` for grouped convs and
    ``block_args=dict(attn_layer='se')``, this builds an SE-ResNeXt. Weights are
    randomly initialised here (``init_weights``); nothing pretrained is loaded.

    Args:
        block: residual block class (e.g. ``Bottleneck``); must expose ``expansion``.
        layers: number of blocks per stage.
        num_classes: number of classifier outputs.
        in_chans: input image channels.
        output_stride: overall feature stride (8, 16 or 32); smaller values dilate later stages.
        cardinality: number of groups in each block's 3x3 conv.
        base_width: per-group width factor for the blocks.
        block_reduce_first: forwarded to blocks as ``reduce_first``.
        down_kernel_size: kernel size of shortcut projections.
        channels: base planes per stage.
        drop_rate: dropout probability on pooled features before ``fc``.
        zero_init_last: call each block's ``zero_init_last()`` after init.
        block_args: extra keyword arguments forwarded to every block.
    """

    def __init__(
            self,
            block: "Bottleneck",
            layers: Tuple[int, ...],
            num_classes: int = 1000,
            in_chans: int = 3,
            output_stride: int = 32,
            cardinality: int = 1,
            base_width: int = 64,
            block_reduce_first: int = 1,
            down_kernel_size: int = 1,
            channels: Optional[Tuple[int, ...]] = (64, 128, 256, 512),
            drop_rate: float = 0.0,
            zero_init_last: bool = True,
            block_args: Optional[Dict[str, Any]] = None,
    ):
        super(SEResNext, self).__init__()
        block_args = block_args or dict()
        assert output_stride in (8, 16, 32)
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.grad_checkpointing = False

        # Stem configuration is fixed: stem_type '' selects the single 7x7 conv stem
        stem_width: int = 64
        stem_type: str = ''

        act_layer = nn.ReLU
        norm_layer = nn.BatchNorm2d

        # Stem
        deep_stem = 'deep' in stem_type
        inplanes = stem_width * 2 if deep_stem else 64
        if deep_stem:
            stem_chs = (stem_width, stem_width)
            if 'tiered' in stem_type:
                stem_chs = (3 * (stem_width // 4), stem_width)
            self.conv1 = nn.Sequential(*[
                nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False),
                norm_layer(stem_chs[0]),
                act_layer(inplace=True),
                nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False),
                norm_layer(stem_chs[1]),
                act_layer(inplace=True),
                nn.Conv2d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False)])
        else:
            self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(inplanes)
        self.act1 = act_layer(inplace=True)
        self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Feature Blocks
        block_fns = to_ntuple(len(channels))(block)
        stage_modules, stage_feature_info = make_blocks(
            block_fns,
            channels,
            layers,
            inplanes,
            cardinality=cardinality,
            base_width=base_width,
            output_stride=output_stride,
            reduce_first=block_reduce_first,
            avg_down=False,
            down_kernel_size=down_kernel_size,
            act_layer=act_layer,
            norm_layer=norm_layer,
            aa_layer=None,
            drop_block_rate=0.,
            drop_path_rate=0.,
            **block_args,
        )
        for stage in stage_modules:
            self.add_module(*stage)  # layer1, layer2, etc
        self.feature_info.extend(stage_feature_info)

        # Head (Pooling and Classifier)
        self.num_features = self.head_hidden_size = channels[-1] * block_fns[-1].expansion
        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type='avg')

        self.init_weights(zero_init_last=zero_init_last)

    def init_weights(self, zero_init_last: bool = True):
        """Kaiming-normal (fan_out, ReLU) init for every Conv2d; optionally zero-init blocks' last norm."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if zero_init_last:
            for m in self.modules():
                if hasattr(m, 'zero_init_last'):
                    m.zero_init_last()

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Stem + all residual stages; returns the final feature map."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        """Global pool, optional dropout, then ``fc`` (skipped when ``pre_logits`` is True)."""
        x = self.global_pool(x)
        if self.drop_rate:
            # Use torch's functional dropout explicitly: in this notebook `F` is bound to
            # torchvision.transforms.functional, which has no `dropout` (previous AttributeError).
            x = nn.functional.dropout(x, p=float(self.drop_rate), training=self.training)
        return x if pre_logits else self.fc(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x

In [25]:
class SEResNextClassifier(nn.Module):
    """timm ``seresnext101_32x8d`` with pretrained weights and a ``num_classes``-way classification head."""

    def __init__(self, num_classes):
        super(SEResNextClassifier, self).__init__()
        self.backbone = timm.create_model(
            model_name="seresnext101_32x8d",
            pretrained=True,
            num_classes=num_classes,
        )

    def forward(self, x):
        """Return the backbone's class scores for the image batch ``x``."""
        scores = self.backbone(x)
        return scores

In [None]:
# Specify the number of classes
num_classes = len(DEFECT_CLASSES)
model_name = config["model"]

if model_name == "resnet":
    model = ResNetClassifier(num_classes)
elif model_name == "efficientnet":
    model = EfficientNetClassifier(num_classes)
elif model_name == "vit":
    model = ViTClassifier(num_classes)
elif model_name == "seresnext":
    # Previously the from-scratch SEResNext was built and then immediately overwritten by the
    # timm model, wasting time and memory. Build only one: timm's pretrained seresnext101_32x8d
    # by default (what was effectively used before), or the local randomly-initialised
    # SEResNext (cardinality 32, base width 4) when config["seresnext_impl"] == "custom".
    if config.get("seresnext_impl", "timm") == "custom":
        model_args = dict(
            block=Bottleneck,
            layers=(3, 4, 23, 3),
            cardinality=32,
            base_width=4,
            block_args=dict(attn_layer='se'))
        model = SEResNext(**model_args, num_classes=num_classes)
    else:
        model = SEResNextClassifier(num_classes)
else:
    # Fail here rather than with a NameError on `model` further down
    raise ValueError(
        f"Unknown config['model'] = {model_name!r}; expected 'resnet', 'efficientnet', 'vit' or 'seresnext'"
    )

# Move the model to GPU if available
model.to(device)

In [None]:
# Layer-by-layer model summary, using one shuffled training batch as the example input
_probe_batch = next(iter(DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)), None)
if _probe_batch is not None:
    x, y = _probe_batch
    summary(model, x.to(device))
del _probe_batch

In [23]:
# Checkpoint helpers. The scheduler entry is optional (stored as None when no scheduler is used).
def save_model(model, optimizer, scheduler, metrics, epoch, path):
    """Write a training checkpoint to ``path`` with ``torch.save``.

    Args:
        model: module whose ``state_dict()`` is stored under ``'model_state_dict'``.
        optimizer: optimizer whose ``state_dict()`` is stored under ``'optimizer_state_dict'``.
        scheduler: LR scheduler or ``None``; its state (or ``None``) goes under ``'scheduler_state_dict'``.
        metrics: arbitrary metrics object stored under ``'metric'``.
        epoch: epoch number stored under ``'epoch'``.
        path: destination file path.
    """
    torch.save(
        {'model_state_dict'         : model.state_dict(),
         'optimizer_state_dict'     : optimizer.state_dict(),
         # Previously crashed with AttributeError when scheduler was None
         'scheduler_state_dict'     : scheduler.state_dict() if scheduler is not None else None,
         'metric'                   : metrics,
         'epoch'                    : epoch},
         path)


def load_model(model, optimizer=None, scheduler=None, path='./checkpoint.pth', map_location=None):
    """Restore state written by ``save_model``.

    Args:
        model: module to load ``'model_state_dict'`` into (in place).
        optimizer: optimizer to restore, or ``None`` to skip.
        scheduler: scheduler to restore, or ``None`` to skip; also skipped when the
            checkpoint's scheduler state is missing or ``None``.
        path: checkpoint file path.
        map_location: forwarded to ``torch.load`` (e.g. ``"cpu"`` or a ``torch.device``)
            so checkpoints written on another device can be loaded.

    Returns:
        ``(model, optimizer, scheduler, epoch, metrics)``.
    """
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler_state = checkpoint.get('scheduler_state_dict')
    if scheduler is not None and scheduler_state is not None:
        scheduler.load_state_dict(scheduler_state)
    return model, optimizer, scheduler, checkpoint['epoch'], checkpoint['metric']

# 5. Training Components

## 5.1 Loss Function

In [25]:
# Only label-smoothed cross-entropy is implemented
if config["loss"] != "crossEntropy":
    raise NotImplementedError
classification_criterion = nn.CrossEntropyLoss(label_smoothing=config["ce_smoothing_factor"])

## 5.2 Optimizer

In [26]:
# Only AdamW is implemented (the config comment also mentions Adam / SGD)
if config["optimizer"] != "AdamW":
    raise NotImplementedError
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config["learning_rate"],
    weight_decay=config["weight_decay"],
)

## 5.3 Learning Rate Scheduler

In [27]:
# Only ReduceLROnPlateau is implemented. mode='min': the training loop steps it with the
# negated validation accuracy, so "minimise" means "maximise accuracy".
if config["scheduler"] != "ReduceLR":
    raise NotImplementedError
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=config["reducelr_factor"],
    patience=config["reducelr_patience"],
    threshold=config["reducelr_threshold"],
    min_lr=config["reducelr_minlr"],
)

## 5.4 wandb

In [None]:
USE_WANDB = config['use_wandb']
# Set to True (together with USE_WANDB) to resume a previous wandb run and restore its checkpoint
RESUME_LOGGING = False

run_name = str(config["model"])

expt_root = os.path.join(os.getcwd(), "exp", run_name)
os.makedirs(expt_root, exist_ok=True)

if USE_WANDB:
    # Read the API key from the environment instead of hardcoding it in the notebook
    wandb.login(key=os.environ.get("WANDB_API_KEY"), relogin=True)

    if RESUME_LOGGING:
        resume_id = "test"  ### Insert the id of the run you want to resume
        run = wandb.init(
            id      = resume_id,
            resume  = True,         ### Needed to resume a previous run (do not pass reinit=True here)
            project = "project",    ### Project should be created in your wandb account
        )

    else:
        run = wandb.init(
            name    = run_name,     ### Wandb creates random run names if you skip this field
            reinit  = True,         ### Allows reinitializing runs when you re-run this cell
            project = "project",    ### Project should be created in your wandb account
            config  = config        ### Wandb config for your run
        )

        # Save the model architecture as text in the experiment directory
        model_arch = str(model)
        model_path = os.path.join(expt_root, "model_arch.txt")
        with open(model_path, "w") as arch_file:
            arch_file.write(model_arch)
        # TODO: also attach it to the run if desired, e.g. wandb.save(model_path)


### Create a local directory with all the checkpoints
shutil.copy(os.path.join(os.getcwd(), config_path), os.path.join(expt_root, 'config.yaml'))
e                   = 0
# NOTE(review): best_perplexity / best_dist look unrelated to this classification task;
# confirm nothing downstream reads them before removing.
best_loss           = 1.2
best_perplexity     = 23.0
best_dist           = 60
checkpoint_root = os.path.join(expt_root, 'checkpoints')
text_root       = os.path.join(expt_root, 'out_text')
os.makedirs(checkpoint_root, exist_ok=True)
os.makedirs(text_root,       exist_ok=True)
checkpoint_best_loss_model_filename     = 'checkpoint-best-loss-modelfull.pth'
checkpoint_last_epoch_filename          = 'checkpoint-epochfull-'
best_loss_model_path                    = os.path.join(checkpoint_root, checkpoint_best_loss_model_filename)

if USE_WANDB:
    wandb.watch(model, log="all")
    wandb.save(config_path)

# RESUME_LOGGING must not be re-assigned between its definition and here (that silently
# disabled resuming). Both flags are required: resume_id only exists when USE_WANDB is on.
if RESUME_LOGGING and USE_WANDB:
    # change if you want to load best test model accordingly;
    # wandb expects run_path as "<entity>/<project>/<run_id>"
    restored = wandb.restore(checkpoint_best_loss_model_filename, run_path="" + resume_id)
    # map_location lets a checkpoint written on another GPU/CPU load onto `device`
    checkpoint = torch.load(restored.name, map_location=device)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    e = checkpoint['epoch']

    print("Resuming from epoch {}".format(e+1))
    print("Epochs left: ", config['epochs']-e)
    print("Optimizer: \n", optimizer)

torch.cuda.empty_cache()
gc.collect()

# 6. Functions

## 6.1 Metrics

In [29]:
class metric_accuracy:
    """Running overall and per-class accuracy accumulated over batches of class scores.

    Units differ: ``accuracy`` is a percentage (0-100) while each entry of
    ``accuracy_by_class`` is a fraction (0-1). Classes not seen yet report 0.

    Args:
        num_classes: number of classes; defaults to ``len(DEFECT_CLASSES)``.
    """

    def __init__(self, num_classes=None):
        self.num_classes = len(DEFECT_CLASSES) if num_classes is None else num_classes
        self.total = 0
        self.total_correct = 0

        self.correct_by_class = [0] * self.num_classes
        self.total_by_class = [0] * self.num_classes

        self.accuracy = 0
        self.accuracy_by_class = [0] * self.num_classes

    def update(self, outputs, labels):
        """Accumulate one batch.

        Args:
            outputs: per-class scores; the prediction is ``outputs.argmax(dim=1)``.
            labels: integer class targets, one per row of ``outputs``.
        """
        preds = outputs.argmax(dim=1)
        hits = preds == labels

        self.total += outputs.shape[0]
        self.total_correct += hits.sum().item()

        for cls in range(self.num_classes):
            in_cls = labels == cls
            self.correct_by_class[cls] += (hits & in_cls).sum().item()
            self.total_by_class[cls] += in_cls.sum().item()

        # Guard against an empty first batch (previously ZeroDivisionError)
        if self.total:
            self.accuracy = 100 * self.total_correct / self.total

        # Per-class accuracy; denominator clamped to 1 for classes not seen yet.
        # Updated in place so previously returned lists stay current.
        self.accuracy_by_class[:] = [
            correct / (seen if seen != 0 else 1)
            for correct, seen in zip(self.correct_by_class, self.total_by_class)
        ]

    def get(self):
        """Return ``(overall accuracy in %, list of per-class accuracy fractions)``."""
        return self.accuracy, self.accuracy_by_class

## 6.2 Training

In [30]:
# Structures to track metrics over epochs (one list per curve)
train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []
# Display names indexed by class id (same order as DEFECT_CLASSES)
class_names = [
    "Mouse Bites",
    "Shorts",
    "Open Circuits",
    "Spurs",
    "Missing Holes",
    "Spurious Coppers",
]

def train_one_epoch(model, train_loader, optimizer, criterion, device):
    """Run one optimisation pass over ``train_loader``.

    Returns:
        ``(mean per-batch loss, overall accuracy in %, per-class accuracy fractions)``.
    """
    model.train()
    loss_sum = 0.0
    tracker = metric_accuracy()

    batches = tqdm(train_loader, desc="Training", leave=False)
    for batch_imgs, batch_labels in batches:
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_imgs)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        tracker.update(logits, batch_labels)

        running_acc = tracker.get()[0]
        batches.set_postfix(loss=batch_loss.item(), accuracy=f"{running_acc:.2f}%")

    mean_loss = loss_sum / len(train_loader)
    overall_acc, per_class_acc = tracker.get()
    return mean_loss, overall_acc, per_class_acc

## 6.3 Validation

In [31]:
def evaluate(model, val_loader, criterion, device):
    """Score ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: network to evaluate; switched to eval mode first.
        val_loader: iterable of ``(imgs, labels)`` batches (its ``len`` is used for averaging).
        criterion: loss function called as ``criterion(outputs, labels)``.
        device: device each batch is moved to.

    Returns:
        tuple: ``(mean_batch_loss, overall_accuracy, accuracy_by_class)``, where the last two
        come from ``metric_accuracy.get()``.
    """
    model.eval()
    accuracy_meter = metric_accuracy()
    loss_total = 0.0

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            loss_total += criterion(logits, targets).item()
            accuracy_meter.update(logits, targets)

    # Mean of per-batch losses (not weighted by batch size)
    mean_loss = loss_total / len(val_loader)
    return (mean_loss, *accuracy_meter.get())

# 7. Train

In [None]:
# Tracking metrics for plotting (one entry appended per completed epoch)
train_class_accuracies = {label: [] for label in class_names}
val_class_accuracies = {label: [] for label in class_names}
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []
epochs = config["epochs"]

# NOTE(review): `e` (the first epoch to run) is not defined in this cell; it is presumably
# set by an earlier setup/resume cell — confirm it exists after Restart & Run All.
for epoch in range(e, epochs):
    # Free Python garbage and cached CUDA memory between epochs
    gc.collect()
    torch.cuda.empty_cache()

    print(f"Epoch {epoch+1}/{epochs}")

    # Train for one epoch
    epoch_loss, train_accuracy, train_accuracy_by_class = train_one_epoch(model, train_loader, optimizer, classification_criterion, device)
    train_losses.append(epoch_loss)
    train_accuracies.append(train_accuracy)
    # Index per-class results by class position (as the wandb dicts below do) instead of
    # enumerating the container, which would yield keys rather than accuracies for a dict.
    for i, class_name in enumerate(class_names):
        train_class_accuracies[class_name].append(train_accuracy_by_class[i])

    # Evaluate on validation set
    val_loss, val_accuracy, val_accuracy_by_class = evaluate(model, val_loader, classification_criterion, device)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)
    for i, class_name in enumerate(class_names):
        val_class_accuracies[class_name].append(val_accuracy_by_class[i])

    if config["scheduler"] == "ReduceLR":
        # Accuracy is negated so a plateau scheduler in (presumably) mode="min" treats
        # higher accuracy as better.
        # NOTE(review): with ReduceLROnPlateau(mode="min", threshold_mode="rel") a negative
        # metric flips the relative tolerance (best * (1 - threshold) lies above best), so
        # marginally *worse* accuracies count as improvements. Prefer mode="max" with
        # step(val_accuracy) — confirm how `scheduler` is constructed before changing this.
        scheduler.step(-val_accuracy)
    else:
        scheduler.step()

    curr_lr = float(optimizer.param_groups[0]["lr"])

    # Per-class accuracies keyed by DEFECT_CLASSES name for wandb
    train_acc_dict = {class_name+"_train_acc":train_accuracy_by_class[idx] for class_name, idx in DEFECT_CLASSES.items()}
    val_acc_dict = {class_name+"_val_acc":val_accuracy_by_class[idx] for class_name, idx in DEFECT_CLASSES.items()}

    if USE_WANDB:
        wandb.log({
            "train_loss"       : epoch_loss,
            "train_accuracy"   : train_accuracy,
            "val_loss"         : val_loss,
            "val_accuracy"     : val_accuracy,
            "learning_rate"    : curr_lr,
            **train_acc_dict,
            **val_acc_dict,
        })

    print(f"Epoch [{epoch+1}/{epochs}] - Training Loss: {epoch_loss:.4f}, Training Accuracy: {train_accuracy:.2f}%")
    print(f"Epoch [{epoch+1}/{epochs}] - Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")
    print(f"Training Accuracy by Class: {train_accuracy_by_class}")
    print(f"Validation Accuracy by Class: {val_accuracy_by_class}")

In [None]:
# Close the W&B run. wandb.finish() finishes the active run and is a no-op when none was
# started (e.g. wandb disabled), unlike `run.finish()`, which needs `run` to exist.
wandb.finish()

In [None]:
def _plot_epoch_series(series_by_label, title, ylabel, figsize=(12, 5)):
    """Plot each named per-epoch series on one figure against a 1-based epoch axis.

    Each series gets an x-axis sized to its own length. This keeps x and y the same
    length even when training started from a non-zero epoch or was interrupted mid-epoch.
    """
    plt.figure(figsize=figsize)
    for label, values in series_by_label.items():
        plt.plot(range(1, len(values) + 1), values, label=label)
    plt.xlabel("Epochs")
    plt.ylabel(ylabel)
    plt.legend()
    plt.title(title)
    plt.show()


# The x-axis is derived from the recorded history, not the leftover loop variable `epoch`.
# `range(1, epoch+1)` has one point fewer than the history, which makes plt.plot raise.

# Plot Training and Validation Loss
_plot_epoch_series(
    {"Training Loss": train_losses, "Validation Loss": val_losses},
    "Training and Validation Loss Over Epochs",
    "Loss",
)

# Plot Training and Validation Overall Accuracy
_plot_epoch_series(
    {"Training Accuracy": train_accuracies, "Validation Accuracy": val_accuracies},
    "Training and Validation Overall Accuracy Over Epochs",
    "Accuracy (%)",
)

# metric_accuracy stores per-class accuracy as a fraction (correct / total). Scale it by 100
# to match the "Accuracy (%)" axis used for the overall accuracy.

# Plot Per-Class Accuracy for Training
_plot_epoch_series(
    {f"{name} (Train)": [100 * acc for acc in train_class_accuracies[name]] for name in class_names},
    "Per-Class Training Accuracy Over Epochs",
    "Accuracy (%)",
    figsize=(12, 6),
)

# Plot Per-Class Accuracy for Validation
_plot_epoch_series(
    {f"{name} (Val)": [100 * acc for acc in val_class_accuracies[name]] for name in class_names},
    "Per-Class Validation Accuracy Over Epochs",
    "Accuracy (%)",
    figsize=(12, 6),
)

In [39]:
def test(model, test_loader, device, num_classes=6, class_names=None):
    """
    Evaluates the model on the test set, showing a progress bar and returning overall and per-class accuracy and loss.

    Loss is plain cross-entropy (no label smoothing, no class weights) averaged per sample.
    It is therefore not directly comparable to a training loss computed with smoothing.

    Args:
        model: classifier; ``model(imgs)`` is expected to return one logit column per class
            (assumed shape (batch, num_classes) — TODO confirm).
        test_loader: iterable yielding ``(imgs, labels)`` batches.
        device: device each batch is moved to before the forward pass.
        num_classes: number of class indices to report per-class metrics for.
        class_names: optional display names indexed by class id. Defaults to the six PCB
            defect names. Any index without a name is printed as ``"Class {i}"``.

    Returns:
        tuple: ``(overall_loss, overall_accuracy, per_class_accuracy, per_class_loss)``.
        Accuracies are percentages. Classes with no samples report 0 for both metrics.
    """
    if class_names is None:
        class_names = ["Mouse Bites", "Shorts", "Open Circuits", "Spurs", "Missing Holes", "Spurious Coppers"]
    # Exactly one name per reported class, so printing never indexes past the provided names
    display_names = [class_names[i] if i < len(class_names) else f"Class {i}" for i in range(num_classes)]

    model.eval()
    running_loss = 0.0
    correct = [0] * num_classes
    total = [0] * num_classes
    class_losses = [0.0] * num_classes  # cumulative per-sample loss for each class
    overall_correct = 0
    overall_total = 0
    criterion = nn.CrossEntropyLoss(reduction='none')  # per-sample losses enable per-class sums

    with torch.no_grad():
        progress_bar = tqdm(test_loader, desc="Testing", leave=False)
        for imgs, labels in progress_bar:
            imgs, labels = imgs.to(device), labels.to(device)

            # Forward pass
            outputs = model(imgs)
            loss = criterion(outputs, labels)  # one loss value per sample

            # Accumulate overall loss (sum over samples; averaged by sample count later)
            running_loss += loss.sum().item()

            # Overall correct predictions
            preds = outputs.argmax(dim=1)
            overall_correct += (preds == labels).sum().item()
            overall_total += labels.size(0)

            # Per-class correct predictions and per-class loss sums
            for label in range(num_classes):
                class_mask = (labels == label)
                correct[label] += ((preds == label) & class_mask).sum().item()
                total[label] += class_mask.sum().item()
                class_losses[label] += loss[class_mask].sum().item()

            # Running per-sample loss and accuracy on the progress bar
            current_loss = running_loss / overall_total if overall_total > 0 else 0
            current_accuracy = 100 * overall_correct / overall_total if overall_total > 0 else 0
            progress_bar.set_postfix(loss=current_loss, accuracy=f"{current_accuracy:.2f}%")

    # Average losses and final metrics (guarded against empty loaders / absent classes)
    overall_loss = running_loss / overall_total if overall_total > 0 else 0
    overall_accuracy = 100 * overall_correct / overall_total if overall_total > 0 else 0
    per_class_accuracy = [100 * correct[i] / total[i] if total[i] > 0 else 0 for i in range(num_classes)]
    per_class_loss = [class_losses[i] / total[i] if total[i] > 0 else 0 for i in range(num_classes)]

    print(f"Overall Test Loss: {overall_loss:.4f}")
    print(f"Overall Test Accuracy: {overall_accuracy:.2f}%")
    print("\nPer-Class Results:")
    for i in range(num_classes):
        print(f"{display_names[i]} - Accuracy: {per_class_accuracy[i]:.2f}%, Loss: {per_class_loss[i]:.4f}")

    return overall_loss, overall_accuracy, per_class_accuracy, per_class_loss


In [None]:
# Run the test set once, then chart per-class accuracy and loss
test_loss, test_accuracy, test_per_class_accuracy, test_per_class_loss = test(model, test_loader, device)

# Display names for the defect types, in class-index order
class_names = ["Mouse Bites", "Shorts", "Open Circuits", "Spurs", "Missing Holes", "Spurious Coppers"]


def _annotated_bar_chart(values, ylabel, title, color):
    """Draw one bar per entry of ``class_names``, printing each value (2 d.p.) above its bar."""
    plt.figure(figsize=(8, 6))
    for bar in plt.bar(class_names, values, color=color):
        top = bar.get_height()
        center_x = bar.get_x() + bar.get_width() / 2
        plt.text(center_x, top, f'{top:.2f}', ha='center', va='bottom')
    plt.xlabel("Classes")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.xticks(rotation=45, ha='right')  # keep long class names readable
    plt.tight_layout()                   # leave room for the rotated tick labels
    plt.show()


_annotated_bar_chart(test_per_class_accuracy, "Accuracy (%)", "Per-Class Test Accuracy", 'skyblue')
_annotated_bar_chart(test_per_class_loss, "Loss", "Per-Class Test Loss", 'salmon')