In [25]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

from tqdm.auto import tqdm 

import torch 
import torchvision
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
import torch.optim as optim 

from torchvision.models import resnet50 , ResNet50_Weights


In [26]:
# Class for extracting intermediate feature maps (layer2 and layer3) from a pretrained ResNet50.
class resnet_feature_extractor(torch.nn.Module):
    """Frozen ResNet50 that returns concatenated layer2 + layer3 feature maps.

    Forward hooks on the last block of ``layer2`` and ``layer3`` capture both
    intermediate maps during an ordinary backbone forward pass. Each captured map is
    smoothed with an unpadded 3x3 average pool, resized with adaptive average pooling
    to the raw layer2 spatial size (H, W), and the maps are concatenated on channels.

    NOTE: this notebook reports that adaptive average pooling misbehaves on the MPS
    backend when sizes do not divide evenly; ``ResNetFeatureExtractorFast`` resizes
    with ``F.interpolate`` instead.
    """

    def __init__(self):
        super(resnet_feature_extractor, self).__init__()
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT)

        # Pure feature extractor: eval mode and no weight updates.
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False

        # Filled by the hooks during forward(); initialised here so it always exists.
        self.features = []

        def hook(module, input, output) -> None:
            # Store a detached copy of this stage's output for forward() to consume.
            self.features.append(output.detach())

        # Keep the handles so the hooks can be detached later with remove_hooks().
        self.handles = [
            self.model.layer2[-1].register_forward_hook(hook),
            self.model.layer3[-1].register_forward_hook(hook),
        ]

    def forward(self, input):
        """Return smoothed layer2 + layer3 features resized to layer2's (H, W).

        Args:
            input: image batch passed unchanged to the ResNet50 backbone.

        Returns:
            Tensor of shape (B, C2 + C3, H2, W2), where (H2, W2) is the spatial size
            of the raw layer2 map and C2/C3 are the layer2/layer3 channel counts.

        Raises:
            RuntimeError: if no feature maps were captured (e.g. hooks removed).
        """
        self.features = []
        # Frozen backbone: no autograd graph is needed, which saves memory and time.
        with torch.no_grad():
            _ = self.model(input)

        if not self.features:
            raise RuntimeError("No features captured. Are hooks registered?")

        # Use both H and W of the raw layer2 map so non-square inputs keep their
        # aspect ratio (an int size would force a square output).
        target_hw = tuple(self.features[0].shape[-2:])

        resized_maps = [
            torch.nn.functional.adaptive_avg_pool2d(
                # Unpadded 3x3 smoothing shrinks each side by 2 before the resize.
                torch.nn.functional.avg_pool2d(fmap, kernel_size=3, stride=1),
                target_hw,
            )
            for fmap in self.features
        ]
        return torch.cat(resized_maps, 1)

    def remove_hooks(self):
        """Detach the layer2/layer3 forward hooks registered in ``__init__``."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

In [27]:
import torch
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights

class ResNetFeatureExtractorFast(torch.nn.Module):
    """Frozen ResNet50 backbone that yields concatenated layer2 + layer3 patch features.

    Forward hooks on the last block of ``layer2`` and ``layer3`` collect both maps.
    ``forward`` smooths each map with a same-size 3x3 average pool, bilinearly resizes
    it (via ``F.interpolate``) to a common spatial size and concatenates on channels.
    """

    def __init__(self, pretrained=True, device=None):
        super().__init__()
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT if pretrained else None)
        self.model.eval()
        # The backbone is only used for feature extraction: freeze every weight.
        for param in self.model.parameters():
            param.requires_grad = False

        self.features = []
        # Handles are retained so remove_hooks() can detach the hooks later.
        self.handles = [
            stage[-1].register_forward_hook(self._hook)
            for stage in (self.model.layer2, self.model.layer3)
        ]

        if device is not None:
            self.to(device)

    def _hook(self, module, input, output):
        """Forward hook: append a detached copy of ``output`` (same device) to ``self.features``."""
        self.features.append(output.detach())

    def forward(self, x, target_spatial=None):
        """Return the concatenated, smoothed and resized layer2/layer3 feature maps.

        Args:
            x: image batch fed to the backbone.
            target_spatial: output (H, W); defaults to the layer2 map's spatial size.

        Raises:
            RuntimeError: if the hooks captured nothing (e.g. after remove_hooks()).
        """
        self.features = []
        with torch.no_grad():
            self.model(x)

        if not self.features:
            raise RuntimeError("No features captured. Are hooks registered?")

        if target_spatial is None:
            reference = self.features[0]
            target_spatial = (reference.shape[-2], reference.shape[-1])

        def _smooth_and_resize(fmap):
            # padding=1 keeps the spatial size unchanged.
            smoothed = F.avg_pool2d(fmap, kernel_size=3, stride=1, padding=1)
            if smoothed.shape[-2:] == target_spatial:
                return smoothed
            # Bilinear interpolation works on both MPS and CPU.
            return F.interpolate(smoothed, size=target_spatial,
                                 mode='bilinear', align_corners=False)

        return torch.cat([_smooth_and_resize(fm) for fm in self.features], dim=1)

    def remove_hooks(self):
        """Detach every registered forward hook and forget the handles."""
        for handle in self.handles:
            handle.remove()
        self.handles = []


In [28]:
import torch.nn.functional as F

def extract_patch_no_hooks(resnet_model, x, target_spatial=None):
    """Hook-free extraction of concatenated layer2 + layer3 ResNet features.

    Runs ``resnet_model`` stage by stage up to ``layer3``. Each of the layer2 and
    layer3 outputs is smoothed with a same-size 3x3 average pool and bilinearly
    resized to ``target_spatial``. The two maps are then concatenated on channels.

    Args:
        resnet_model: module exposing ``conv1, bn1, relu, maxpool, layer1, layer2,
            layer3`` (torchvision ResNet layout).
        x: image batch fed to ``conv1`` (presumably (B, 3, H, W), normalised like the
            backbone's training data - TODO confirm for non-default backbones).
        target_spatial: output (H, W) as a tuple/list, or an int for a square output.
            Defaults to layer2's spatial size.

    Returns:
        Tensor of shape (B, C2 + C3, *target_spatial), e.g. (B, 1536, H2, W2) for ResNet50.

    Note:
        No ``torch.no_grad()`` is applied here. Wrap the call yourself if the backbone's
        parameters require grad and no autograd graph is wanted.
    """
    x = resnet_model.conv1(x)
    x = resnet_model.bn1(x)
    x = resnet_model.relu(x)
    x = resnet_model.maxpool(x)
    x = resnet_model.layer1(x)
    out2 = resnet_model.layer2(x)        # (B,512,H2,W2) for ResNet50
    out3 = resnet_model.layer3(out2)     # (B,1024,H3,W3) for ResNet50

    # Normalise to a tuple so the shape comparison below is meaningful.
    if target_spatial is None:
        target_spatial = tuple(out2.shape[-2:])
    elif isinstance(target_spatial, int):
        target_spatial = (target_spatial, target_spatial)
    else:
        target_spatial = tuple(target_spatial)

    def _smooth_resize(fmap):
        # Same-size 3x3 smoothing, then bilinear resize only when needed.
        fmap = F.avg_pool2d(fmap, kernel_size=3, stride=1, padding=1)
        if tuple(fmap.shape[-2:]) != target_spatial:
            fmap = F.interpolate(fmap, size=target_spatial, mode='bilinear', align_corners=False)
        return fmap

    # Both maps must share the spatial size for concatenation, so resize both
    # (previously only layer3 was resized, which broke a custom target_spatial).
    return torch.cat([_smooth_resize(out2), _smooth_resize(out3)], dim=1)


In [29]:
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

# Use the class defined above. The constructor already puts the backbone in eval
# mode and freezes its weights, so no extra freezing is needed here.
extractor = ResNetFeatureExtractorFast(pretrained=True, device=device)

# Smoke test on a random batch: 224x224 inputs give (2, 1536, 28, 28)
# (512 layer2 + 1024 layer3 channels at layer2's resolution).
x = torch.randn(2, 3, 224, 224).to(device)
patch = extractor(x)
print(patch.shape)
assert patch.shape[:2] == (2, 1536), f"unexpected extractor output shape {tuple(patch.shape)}"

# when done in notebook
# extractor.remove_hooks()


torch.Size([2, 1536, 28, 28])


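### Why `F.interpolate` instead of `AdaptiveAvgPool2d`

`ResNetFeatureExtractorFast` resizes the feature maps with `F.interpolate` rather than `nn.AdaptiveAvgPool2d`. On the MPS backend, adaptive average pooling fails when the output size does not evenly divide the input size. `F.interpolate` handles arbitrary sizes and is supported on MPS.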



In [30]:
import torch, os
from pathlib import Path
from torch.utils.data import DataLoader, Subset
import torch.optim as optim
import torch.nn.functional as F
from tqdm.auto import tqdm

# Prefer the Apple-silicon GPU (MPS) when available, otherwise fall back to CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Device:", device)

# Output directory for the per-category CAE checkpoints.
CAEMODEL_DIR = Path("saved_spatial_caes_simple")
CAEMODEL_DIR.mkdir(exist_ok=True)


Device: mps


In [31]:
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# ImageNet statistics expected by the pretrained ResNet50 backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize to 224x224 -> float tensor -> per-channel ImageNet normalisation.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

# One ImageFolder per split; each class sub-folder becomes one label.
train_ds = ImageFolder("mvtec_all/train", transform=transform)
train_loader = DataLoader(
    train_ds,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)

val_ds = ImageFolder("mvtec_all/val", transform=transform)
val_loader = DataLoader(
    val_ds,
    batch_size=32,
    shuffle=False,
    num_workers=4,
)


In [32]:
# Re-create train_ds from the training split (same folder and `transform` as the previous cell); its class sub-folders define the category list.
from torchvision.datasets import ImageFolder

train_ds = ImageFolder(
    "mvtec_all/train",
    transform=transform,
)

# ImageFolder.classes lists the class sub-folder names: one CAE is trained per entry.
category_list = train_ds.classes

def loader_for_category_from_full(train_ds, cat_name, batch_size=8, shuffle=True, num_workers=0):
    """Build a DataLoader over only the samples of one class of an ImageFolder-style dataset.

    Args:
        train_ds: dataset exposing ``class_to_idx`` and ``samples`` (a list of
            ``(path, label)`` pairs), as ``ImageFolder`` does.
        cat_name: class name to keep. Raises ``KeyError`` if absent from ``class_to_idx``.
        batch_size, shuffle, num_workers: forwarded unchanged to ``DataLoader``.

    Returns:
        A ``DataLoader`` over a ``Subset`` of ``train_ds`` containing the indices
        whose label matches ``cat_name``, in dataset order.
    """
    target_label = train_ds.class_to_idx[cat_name]
    matching = [
        position
        for position, (_path, label) in enumerate(train_ds.samples)
        if label == target_label
    ]
    category_subset = Subset(train_ds, matching)
    return DataLoader(
        category_subset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


In [33]:
import torch.nn as nn

class FeatCAE(nn.Module):
    """Convolutional autoencoder over feature maps, built only from 1x1 convolutions.

    Every spatial position is encoded and decoded independently (kernel_size=1):
    ``in_channels -> (in_channels + 2*latent_dim)//2 -> 2*latent_dim -> latent_dim``
    and back. Each hidden convolution is followed by optional BatchNorm and a ReLU.
    The bottleneck and output convolutions are linear.

    Args:
        in_channels: channels of the input feature map (and of the reconstruction).
        latent_dim: channels of the bottleneck.
        is_bn: insert ``BatchNorm2d`` after every hidden convolution.
    """

    def __init__(self, in_channels=1000, latent_dim=50, is_bn=True):
        super().__init__()
        hidden = (in_channels + 2 * latent_dim) // 2
        wide = 2 * latent_dim

        def conv1x1(c_in, c_out):
            return nn.Conv2d(c_in, c_out, kernel_size=1, stride=1, padding=0)

        def conv_bn_relu(c_in, c_out):
            # Order (conv, [bn], relu) matches the layer indices of saved checkpoints.
            stage = [conv1x1(c_in, c_out)]
            if is_bn:
                stage.append(nn.BatchNorm2d(num_features=c_out))
            stage.append(nn.ReLU())
            return stage

        self.encoder = nn.Sequential(
            *conv_bn_relu(in_channels, hidden),
            *conv_bn_relu(hidden, wide),
            conv1x1(wide, latent_dim),
        )

        # The final 1x1 conv is a learned linear combination of hidden features (no
        # output activation), so reconstructions can take any sign.
        self.decoder = nn.Sequential(
            *conv_bn_relu(latent_dim, wide),
            *conv_bn_relu(wide, hidden),
            conv1x1(hidden, in_channels),
        )

    def forward(self, x):
        """Encode ``x`` to the bottleneck and decode it back to ``in_channels``."""
        return self.decoder(self.encoder(x))

In [20]:
def train_cae_for_category_simple(cat_name,
                                  extractor,
                                  train_ds,
                                  in_channels=1536,
                                  latent_dim=100,
                                  epochs=15,
                                  batch_size=8,
                                  lr=1e-3,
                                  device=device,
                                  save_dir=CAEMODEL_DIR):
    """Train one feature-space autoencoder (FeatCAE) for a single category.

    Images of ``cat_name`` are taken from ``train_ds`` and passed through the frozen
    ``extractor`` on the fly. The CAE is trained with MSE to reconstruct those feature
    maps. Whenever an epoch's mean training loss improves, the CAE ``state_dict`` is
    written to ``<save_dir>/cae_<cat_name>.pth``.

    Args:
        cat_name: class name present in ``train_ds.class_to_idx``.
        extractor: module mapping an image batch to a feature tensor with
            ``in_channels`` channels, exposing its backbone as ``.model``.
        train_ds: ImageFolder-style dataset (``class_to_idx`` and ``samples``).
        in_channels, latent_dim: FeatCAE sizes. ``in_channels`` must match the extractor.
        epochs, batch_size, lr: training hyper-parameters (Adam optimiser).
        device: device for the extractor and CAE. NOTE: the default is the notebook's
            global ``device`` captured when this cell was run.
        save_dir: checkpoint directory, created if missing.

    Returns:
        The FeatCAE loaded with the lowest-loss epoch's weights, i.e. the same weights
        that were written to disk. If no epoch was saved (``epochs`` < 1 or NaN
        losses), the model is returned as it is.
    """
    print(f"=== Train CAE for {cat_name} ===")
    loader = loader_for_category_from_full(train_ds, cat_name, batch_size=batch_size, shuffle=True, num_workers=0)

    # instantiate CAE and optimizer
    cae = FeatCAE(in_channels=in_channels, latent_dim=latent_dim, is_bn=True).to(device).float()
    optimizer = optim.Adam(cae.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()

    # The backbone stays frozen; only the CAE parameters are optimised.
    extractor = extractor.to(device)
    extractor.model.eval()
    for p in extractor.model.parameters():
        p.requires_grad = False

    save_path = Path(save_dir) / f"cae_{cat_name}.pth"
    save_path.parent.mkdir(parents=True, exist_ok=True)

    best_loss = float('inf')
    best_state = None
    for epoch in range(1, epochs + 1):
        cae.train()
        running = 0.0
        n = 0
        pbar = tqdm(loader, desc=f"{cat_name} Epoch {epoch}/{epochs}", leave=False)
        for imgs, _ in pbar:
            imgs = imgs.to(device)

            # extract features on-the-fly (backbone frozen)
            with torch.no_grad():
                features = extractor(imgs)

            optimizer.zero_grad()
            outputs = cae(features)          # must be same shape as features
            loss = criterion(outputs, features)
            loss.backward()
            optimizer.step()

            # A single .item() per step: every call forces a device->host sync.
            batch_loss = loss.item()
            running += batch_loss * imgs.size(0)
            n += imgs.size(0)
            pbar.set_postfix({"batch_loss": batch_loss})

        epoch_loss = running / max(1, n)
        print(f"{cat_name} Epoch {epoch}/{epochs} loss: {epoch_loss:.6f}")

        # Keep (and persist) the best weights seen so far, judged by training loss.
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            best_state = {k: v.detach().clone() for k, v in cae.state_dict().items()}
            torch.save(best_state, save_path)
            print(f"Saved best CAE -> {save_path} (loss {best_loss:.6f})")

    # Return the same weights that were checkpointed, not just the last epoch's.
    if best_state is not None:
        cae.load_state_dict(best_state)
    return cae


In [34]:

device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

# The constructor already sets eval mode, freezes the backbone and moves it to `device`.
extractor = ResNetFeatureExtractorFast(pretrained=True, device=device)

# Smoke test: the hooks must fire and the patch must carry 512 + 1024 = 1536 channels.
x = torch.randn(1, 3, 224, 224).to(device)
with torch.no_grad():
    patch = extractor(x)
print("extractor patch shape:", patch.shape)  # expect (1,1536, H, W), e.g. (1,1536,28,28)
assert patch.shape[:2] == (1, 1536), f"unexpected extractor output shape {tuple(patch.shape)}"

# NOTE(review): train_cae_for_category_simple builds its own FeatCAE, so this `model`
# is not used by the training loop; kept in case later cells rely on it.
model = FeatCAE(in_channels=1536, latent_dim=100).to(device)


extractor patch shape: torch.Size([1, 1536, 28, 28])


In [35]:
# Training hyper-parameters (tune as needed).
EPOCHS = 15
BATCH_SIZE = 8
LR = 1e-3
LATENT_DIM = 100

# Train one CAE per category; categories that already have a checkpoint are skipped.
# NOTE: the trainer writes a checkpoint after every improving epoch, so an interrupted
# run leaves a partial checkpoint that is skipped here. Delete it to retrain.
for cat in category_list:
    ckpt = CAEMODEL_DIR / f"cae_{cat}.pth"
    if ckpt.exists():
        print("Skipping (exists):", cat)
    else:
        _ = train_cae_for_category_simple(
            cat,
            extractor=extractor,
            train_ds=train_ds,
            in_channels=1536,
            latent_dim=LATENT_DIM,
            epochs=EPOCHS,
            batch_size=BATCH_SIZE,
            lr=LR,
            device=device,
            save_dir=CAEMODEL_DIR,
        )


=== Train CAE for bottle ===


bottle Epoch 1/15:   0%|          | 0/27 [00:00<?, ?it/s]

bottle Epoch 1/15 loss: 0.256879
Saved best CAE -> saved_spatial_caes_simple/cae_bottle.pth (loss 0.256879)


bottle Epoch 2/15:   0%|          | 0/27 [00:00<?, ?it/s]

bottle Epoch 2/15 loss: 0.054569
Saved best CAE -> saved_spatial_caes_simple/cae_bottle.pth (loss 0.054569)


bottle Epoch 3/15:   0%|          | 0/27 [00:00<?, ?it/s]

bottle Epoch 3/15 loss: 0.037513
Saved best CAE -> saved_spatial_caes_simple/cae_bottle.pth (loss 0.037513)


bottle Epoch 4/15:   0%|          | 0/27 [00:00<?, ?it/s]

bottle Epoch 4/15 loss: 0.031338
Saved best CAE -> saved_spatial_caes_simple/cae_bottle.pth (loss 0.031338)


bottle Epoch 5/15:   0%|          | 0/27 [00:00<?, ?it/s]

bottle Epoch 5/15 loss: 0.027759
Saved best CAE -> saved_spatial_caes_simple/cae_bottle.pth (loss 0.027759)


bottle Epoch 6/15:   0%|          | 0/27 [00:00<?, ?it/s]

bottle Epoch 6/15 loss: 0.025407
Saved best CAE -> saved_spatial_caes_simple/cae_bottle.pth (loss 0.025407)


bottle Epoch 7/15:   0%|          | 0/27 [00:00<?, ?it/s]

bottle Epoch 7/15 loss: 0.023596
Saved best CAE -> saved_spatial_caes_simple/cae_bottle.pth (loss 0.023596)


bottle Epoch 8/15:   0%|          | 0/27 [00:00<?, ?it/s]

bottle Epoch 8/15 loss: 0.022387
Saved best CAE -> saved_spatial_caes_simple/cae_bottle.pth (loss 0.022387)


bottle Epoch 9/15:   0%|          | 0/27 [00:00<?, ?it/s]

bottle Epoch 9/15 loss: 0.021342
Saved best CAE -> saved_spatial_caes_simple/cae_bottle.pth (loss 0.021342)


bottle Epoch 10/15:   0%|          | 0/27 [00:00<?, ?it/s]

bottle Epoch 10/15 loss: 0.020464
Saved best CAE -> saved_spatial_caes_simple/cae_bottle.pth (loss 0.020464)


bottle Epoch 11/15:   0%|          | 0/27 [00:00<?, ?it/s]

bottle Epoch 11/15 loss: 0.019744
Saved best CAE -> saved_spatial_caes_simple/cae_bottle.pth (loss 0.019744)


bottle Epoch 12/15:   0%|          | 0/27 [00:00<?, ?it/s]

bottle Epoch 12/15 loss: 0.019180
Saved best CAE -> saved_spatial_caes_simple/cae_bottle.pth (loss 0.019180)


bottle Epoch 13/15:   0%|          | 0/27 [00:00<?, ?it/s]

bottle Epoch 13/15 loss: 0.018960
Saved best CAE -> saved_spatial_caes_simple/cae_bottle.pth (loss 0.018960)


bottle Epoch 14/15:   0%|          | 0/27 [00:00<?, ?it/s]

bottle Epoch 14/15 loss: 0.017999
Saved best CAE -> saved_spatial_caes_simple/cae_bottle.pth (loss 0.017999)


bottle Epoch 15/15:   0%|          | 0/27 [00:00<?, ?it/s]

bottle Epoch 15/15 loss: 0.017502
Saved best CAE -> saved_spatial_caes_simple/cae_bottle.pth (loss 0.017502)
=== Train CAE for cable ===


cable Epoch 1/15:   0%|          | 0/28 [00:00<?, ?it/s]

cable Epoch 1/15 loss: 0.283744
Saved best CAE -> saved_spatial_caes_simple/cae_cable.pth (loss 0.283744)


cable Epoch 2/15:   0%|          | 0/28 [00:00<?, ?it/s]

cable Epoch 2/15 loss: 0.105250
Saved best CAE -> saved_spatial_caes_simple/cae_cable.pth (loss 0.105250)


cable Epoch 3/15:   0%|          | 0/28 [00:00<?, ?it/s]

cable Epoch 3/15 loss: 0.087300
Saved best CAE -> saved_spatial_caes_simple/cae_cable.pth (loss 0.087300)


cable Epoch 4/15:   0%|          | 0/28 [00:00<?, ?it/s]

cable Epoch 4/15 loss: 0.079240
Saved best CAE -> saved_spatial_caes_simple/cae_cable.pth (loss 0.079240)


cable Epoch 5/15:   0%|          | 0/28 [00:00<?, ?it/s]

cable Epoch 5/15 loss: 0.073777
Saved best CAE -> saved_spatial_caes_simple/cae_cable.pth (loss 0.073777)


cable Epoch 6/15:   0%|          | 0/28 [00:00<?, ?it/s]

cable Epoch 6/15 loss: 0.069618
Saved best CAE -> saved_spatial_caes_simple/cae_cable.pth (loss 0.069618)


cable Epoch 7/15:   0%|          | 0/28 [00:00<?, ?it/s]

cable Epoch 7/15 loss: 0.066440
Saved best CAE -> saved_spatial_caes_simple/cae_cable.pth (loss 0.066440)


cable Epoch 8/15:   0%|          | 0/28 [00:00<?, ?it/s]

cable Epoch 8/15 loss: 0.063304
Saved best CAE -> saved_spatial_caes_simple/cae_cable.pth (loss 0.063304)


cable Epoch 9/15:   0%|          | 0/28 [00:00<?, ?it/s]

cable Epoch 9/15 loss: 0.060637
Saved best CAE -> saved_spatial_caes_simple/cae_cable.pth (loss 0.060637)


cable Epoch 10/15:   0%|          | 0/28 [00:00<?, ?it/s]

cable Epoch 10/15 loss: 0.058390
Saved best CAE -> saved_spatial_caes_simple/cae_cable.pth (loss 0.058390)


cable Epoch 11/15:   0%|          | 0/28 [00:00<?, ?it/s]

cable Epoch 11/15 loss: 0.056327
Saved best CAE -> saved_spatial_caes_simple/cae_cable.pth (loss 0.056327)


cable Epoch 12/15:   0%|          | 0/28 [00:00<?, ?it/s]

cable Epoch 12/15 loss: 0.054778
Saved best CAE -> saved_spatial_caes_simple/cae_cable.pth (loss 0.054778)


cable Epoch 13/15:   0%|          | 0/28 [00:00<?, ?it/s]

cable Epoch 13/15 loss: 0.053254
Saved best CAE -> saved_spatial_caes_simple/cae_cable.pth (loss 0.053254)


cable Epoch 14/15:   0%|          | 0/28 [00:00<?, ?it/s]

cable Epoch 14/15 loss: 0.051854
Saved best CAE -> saved_spatial_caes_simple/cae_cable.pth (loss 0.051854)


cable Epoch 15/15:   0%|          | 0/28 [00:00<?, ?it/s]

cable Epoch 15/15 loss: 0.050607
Saved best CAE -> saved_spatial_caes_simple/cae_cable.pth (loss 0.050607)
=== Train CAE for capsule ===


capsule Epoch 1/15:   0%|          | 0/28 [00:00<?, ?it/s]

capsule Epoch 1/15 loss: 0.211970
Saved best CAE -> saved_spatial_caes_simple/cae_capsule.pth (loss 0.211970)


capsule Epoch 2/15:   0%|          | 0/28 [00:00<?, ?it/s]

capsule Epoch 2/15 loss: 0.047264
Saved best CAE -> saved_spatial_caes_simple/cae_capsule.pth (loss 0.047264)


capsule Epoch 3/15:   0%|          | 0/28 [00:00<?, ?it/s]

capsule Epoch 3/15 loss: 0.034787
Saved best CAE -> saved_spatial_caes_simple/cae_capsule.pth (loss 0.034787)


capsule Epoch 4/15:   0%|          | 0/28 [00:00<?, ?it/s]

capsule Epoch 4/15 loss: 0.029997
Saved best CAE -> saved_spatial_caes_simple/cae_capsule.pth (loss 0.029997)


capsule Epoch 5/15:   0%|          | 0/28 [00:00<?, ?it/s]

capsule Epoch 5/15 loss: 0.027114
Saved best CAE -> saved_spatial_caes_simple/cae_capsule.pth (loss 0.027114)


capsule Epoch 6/15:   0%|          | 0/28 [00:00<?, ?it/s]

capsule Epoch 6/15 loss: 0.024996
Saved best CAE -> saved_spatial_caes_simple/cae_capsule.pth (loss 0.024996)


capsule Epoch 7/15:   0%|          | 0/28 [00:00<?, ?it/s]

capsule Epoch 7/15 loss: 0.023639
Saved best CAE -> saved_spatial_caes_simple/cae_capsule.pth (loss 0.023639)


capsule Epoch 8/15:   0%|          | 0/28 [00:00<?, ?it/s]

capsule Epoch 8/15 loss: 0.022293
Saved best CAE -> saved_spatial_caes_simple/cae_capsule.pth (loss 0.022293)


capsule Epoch 9/15:   0%|          | 0/28 [00:00<?, ?it/s]

capsule Epoch 9/15 loss: 0.021296
Saved best CAE -> saved_spatial_caes_simple/cae_capsule.pth (loss 0.021296)


capsule Epoch 10/15:   0%|          | 0/28 [00:00<?, ?it/s]

capsule Epoch 10/15 loss: 0.020469
Saved best CAE -> saved_spatial_caes_simple/cae_capsule.pth (loss 0.020469)


capsule Epoch 11/15:   0%|          | 0/28 [00:00<?, ?it/s]

capsule Epoch 11/15 loss: 0.019799
Saved best CAE -> saved_spatial_caes_simple/cae_capsule.pth (loss 0.019799)


capsule Epoch 12/15:   0%|          | 0/28 [00:00<?, ?it/s]

capsule Epoch 12/15 loss: 0.019164
Saved best CAE -> saved_spatial_caes_simple/cae_capsule.pth (loss 0.019164)


capsule Epoch 13/15:   0%|          | 0/28 [00:00<?, ?it/s]

capsule Epoch 13/15 loss: 0.018480
Saved best CAE -> saved_spatial_caes_simple/cae_capsule.pth (loss 0.018480)


capsule Epoch 14/15:   0%|          | 0/28 [00:00<?, ?it/s]

capsule Epoch 14/15 loss: 0.017980
Saved best CAE -> saved_spatial_caes_simple/cae_capsule.pth (loss 0.017980)


capsule Epoch 15/15:   0%|          | 0/28 [00:00<?, ?it/s]

capsule Epoch 15/15 loss: 0.017360
Saved best CAE -> saved_spatial_caes_simple/cae_capsule.pth (loss 0.017360)
=== Train CAE for carpet ===


carpet Epoch 1/15:   0%|          | 0/35 [00:00<?, ?it/s]

carpet Epoch 1/15 loss: 0.173080
Saved best CAE -> saved_spatial_caes_simple/cae_carpet.pth (loss 0.173080)


carpet Epoch 2/15:   0%|          | 0/35 [00:00<?, ?it/s]

carpet Epoch 2/15 loss: 0.039849
Saved best CAE -> saved_spatial_caes_simple/cae_carpet.pth (loss 0.039849)


carpet Epoch 3/15:   0%|          | 0/35 [00:00<?, ?it/s]

carpet Epoch 3/15 loss: 0.033867
Saved best CAE -> saved_spatial_caes_simple/cae_carpet.pth (loss 0.033867)


carpet Epoch 4/15:   0%|          | 0/35 [00:00<?, ?it/s]

carpet Epoch 4/15 loss: 0.030135
Saved best CAE -> saved_spatial_caes_simple/cae_carpet.pth (loss 0.030135)


carpet Epoch 5/15:   0%|          | 0/35 [00:00<?, ?it/s]

carpet Epoch 5/15 loss: 0.027078
Saved best CAE -> saved_spatial_caes_simple/cae_carpet.pth (loss 0.027078)


carpet Epoch 6/15:   0%|          | 0/35 [00:00<?, ?it/s]

carpet Epoch 6/15 loss: 0.024738
Saved best CAE -> saved_spatial_caes_simple/cae_carpet.pth (loss 0.024738)


carpet Epoch 7/15:   0%|          | 0/35 [00:00<?, ?it/s]

carpet Epoch 7/15 loss: 0.022956
Saved best CAE -> saved_spatial_caes_simple/cae_carpet.pth (loss 0.022956)


carpet Epoch 8/15:   0%|          | 0/35 [00:00<?, ?it/s]

carpet Epoch 8/15 loss: 0.021606
Saved best CAE -> saved_spatial_caes_simple/cae_carpet.pth (loss 0.021606)


carpet Epoch 9/15:   0%|          | 0/35 [00:00<?, ?it/s]

carpet Epoch 9/15 loss: 0.020276
Saved best CAE -> saved_spatial_caes_simple/cae_carpet.pth (loss 0.020276)


carpet Epoch 10/15:   0%|          | 0/35 [00:00<?, ?it/s]

carpet Epoch 10/15 loss: 0.019202
Saved best CAE -> saved_spatial_caes_simple/cae_carpet.pth (loss 0.019202)


carpet Epoch 11/15:   0%|          | 0/35 [00:00<?, ?it/s]

carpet Epoch 11/15 loss: 0.018440
Saved best CAE -> saved_spatial_caes_simple/cae_carpet.pth (loss 0.018440)


carpet Epoch 12/15:   0%|          | 0/35 [00:00<?, ?it/s]

carpet Epoch 12/15 loss: 0.017465
Saved best CAE -> saved_spatial_caes_simple/cae_carpet.pth (loss 0.017465)


carpet Epoch 13/15:   0%|          | 0/35 [00:00<?, ?it/s]

carpet Epoch 13/15 loss: 0.017123
Saved best CAE -> saved_spatial_caes_simple/cae_carpet.pth (loss 0.017123)


carpet Epoch 14/15:   0%|          | 0/35 [00:00<?, ?it/s]

carpet Epoch 14/15 loss: 0.016533
Saved best CAE -> saved_spatial_caes_simple/cae_carpet.pth (loss 0.016533)


carpet Epoch 15/15:   0%|          | 0/35 [00:00<?, ?it/s]

carpet Epoch 15/15 loss: 0.015801
Saved best CAE -> saved_spatial_caes_simple/cae_carpet.pth (loss 0.015801)
=== Train CAE for grid ===


grid Epoch 1/15:   0%|          | 0/33 [00:00<?, ?it/s]

grid Epoch 1/15 loss: 0.313719
Saved best CAE -> saved_spatial_caes_simple/cae_grid.pth (loss 0.313719)


grid Epoch 2/15:   0%|          | 0/33 [00:00<?, ?it/s]

grid Epoch 2/15 loss: 0.094051
Saved best CAE -> saved_spatial_caes_simple/cae_grid.pth (loss 0.094051)


grid Epoch 3/15:   0%|          | 0/33 [00:00<?, ?it/s]

grid Epoch 3/15 loss: 0.079419
Saved best CAE -> saved_spatial_caes_simple/cae_grid.pth (loss 0.079419)


grid Epoch 4/15:   0%|          | 0/33 [00:00<?, ?it/s]

grid Epoch 4/15 loss: 0.072378
Saved best CAE -> saved_spatial_caes_simple/cae_grid.pth (loss 0.072378)


grid Epoch 5/15:   0%|          | 0/33 [00:00<?, ?it/s]

grid Epoch 5/15 loss: 0.070556
Saved best CAE -> saved_spatial_caes_simple/cae_grid.pth (loss 0.070556)


grid Epoch 6/15:   0%|          | 0/33 [00:00<?, ?it/s]

grid Epoch 6/15 loss: 0.063087
Saved best CAE -> saved_spatial_caes_simple/cae_grid.pth (loss 0.063087)


grid Epoch 7/15:   0%|          | 0/33 [00:00<?, ?it/s]

grid Epoch 7/15 loss: 0.058716
Saved best CAE -> saved_spatial_caes_simple/cae_grid.pth (loss 0.058716)


grid Epoch 8/15:   0%|          | 0/33 [00:00<?, ?it/s]

grid Epoch 8/15 loss: 0.054937
Saved best CAE -> saved_spatial_caes_simple/cae_grid.pth (loss 0.054937)


grid Epoch 9/15:   0%|          | 0/33 [00:00<?, ?it/s]

grid Epoch 9/15 loss: 0.056559


grid Epoch 10/15:   0%|          | 0/33 [00:00<?, ?it/s]

grid Epoch 10/15 loss: 0.053738
Saved best CAE -> saved_spatial_caes_simple/cae_grid.pth (loss 0.053738)


grid Epoch 11/15:   0%|          | 0/33 [00:00<?, ?it/s]

grid Epoch 11/15 loss: 0.052283
Saved best CAE -> saved_spatial_caes_simple/cae_grid.pth (loss 0.052283)


grid Epoch 12/15:   0%|          | 0/33 [00:00<?, ?it/s]

grid Epoch 12/15 loss: 0.049906
Saved best CAE -> saved_spatial_caes_simple/cae_grid.pth (loss 0.049906)


grid Epoch 13/15:   0%|          | 0/33 [00:00<?, ?it/s]

grid Epoch 13/15 loss: 0.053134


grid Epoch 14/15:   0%|          | 0/33 [00:00<?, ?it/s]

grid Epoch 14/15 loss: 0.050933


grid Epoch 15/15:   0%|          | 0/33 [00:00<?, ?it/s]

grid Epoch 15/15 loss: 0.049290
Saved best CAE -> saved_spatial_caes_simple/cae_grid.pth (loss 0.049290)
=== Train CAE for hazelnut ===


hazelnut Epoch 1/15:   0%|          | 0/49 [00:00<?, ?it/s]

hazelnut Epoch 1/15 loss: 0.173185
Saved best CAE -> saved_spatial_caes_simple/cae_hazelnut.pth (loss 0.173185)


hazelnut Epoch 2/15:   0%|          | 0/49 [00:00<?, ?it/s]

hazelnut Epoch 2/15 loss: 0.071746
Saved best CAE -> saved_spatial_caes_simple/cae_hazelnut.pth (loss 0.071746)


hazelnut Epoch 3/15:   0%|          | 0/49 [00:00<?, ?it/s]

hazelnut Epoch 3/15 loss: 0.061449
Saved best CAE -> saved_spatial_caes_simple/cae_hazelnut.pth (loss 0.061449)


hazelnut Epoch 4/15:   0%|          | 0/49 [00:00<?, ?it/s]

hazelnut Epoch 4/15 loss: 0.055270
Saved best CAE -> saved_spatial_caes_simple/cae_hazelnut.pth (loss 0.055270)


hazelnut Epoch 5/15:   0%|          | 0/49 [00:00<?, ?it/s]

hazelnut Epoch 5/15 loss: 0.051775
Saved best CAE -> saved_spatial_caes_simple/cae_hazelnut.pth (loss 0.051775)


hazelnut Epoch 6/15:   0%|          | 0/49 [00:00<?, ?it/s]

hazelnut Epoch 6/15 loss: 0.048632
Saved best CAE -> saved_spatial_caes_simple/cae_hazelnut.pth (loss 0.048632)


hazelnut Epoch 7/15:   0%|          | 0/49 [00:00<?, ?it/s]

hazelnut Epoch 7/15 loss: 0.045669
Saved best CAE -> saved_spatial_caes_simple/cae_hazelnut.pth (loss 0.045669)


hazelnut Epoch 8/15:   0%|          | 0/49 [00:00<?, ?it/s]

hazelnut Epoch 8/15 loss: 0.044366
Saved best CAE -> saved_spatial_caes_simple/cae_hazelnut.pth (loss 0.044366)


hazelnut Epoch 9/15:   0%|          | 0/49 [00:00<?, ?it/s]

hazelnut Epoch 9/15 loss: 0.042605
Saved best CAE -> saved_spatial_caes_simple/cae_hazelnut.pth (loss 0.042605)


hazelnut Epoch 10/15:   0%|          | 0/49 [00:00<?, ?it/s]

hazelnut Epoch 10/15 loss: 0.040395
Saved best CAE -> saved_spatial_caes_simple/cae_hazelnut.pth (loss 0.040395)


hazelnut Epoch 11/15:   0%|          | 0/49 [00:00<?, ?it/s]

hazelnut Epoch 11/15 loss: 0.039337
Saved best CAE -> saved_spatial_caes_simple/cae_hazelnut.pth (loss 0.039337)


hazelnut Epoch 12/15:   0%|          | 0/49 [00:00<?, ?it/s]

hazelnut Epoch 12/15 loss: 0.038465
Saved best CAE -> saved_spatial_caes_simple/cae_hazelnut.pth (loss 0.038465)


hazelnut Epoch 13/15:   0%|          | 0/49 [00:00<?, ?it/s]

hazelnut Epoch 13/15 loss: 0.037245
Saved best CAE -> saved_spatial_caes_simple/cae_hazelnut.pth (loss 0.037245)


hazelnut Epoch 14/15:   0%|          | 0/49 [00:00<?, ?it/s]

hazelnut Epoch 14/15 loss: 0.037286


hazelnut Epoch 15/15:   0%|          | 0/49 [00:00<?, ?it/s]

hazelnut Epoch 15/15 loss: 0.035446
Saved best CAE -> saved_spatial_caes_simple/cae_hazelnut.pth (loss 0.035446)
=== Train CAE for leather ===


leather Epoch 1/15:   0%|          | 0/31 [00:00<?, ?it/s]

leather Epoch 1/15 loss: 0.194206
Saved best CAE -> saved_spatial_caes_simple/cae_leather.pth (loss 0.194206)


leather Epoch 2/15:   0%|          | 0/31 [00:00<?, ?it/s]

leather Epoch 2/15 loss: 0.041472
Saved best CAE -> saved_spatial_caes_simple/cae_leather.pth (loss 0.041472)


leather Epoch 3/15:   0%|          | 0/31 [00:00<?, ?it/s]

leather Epoch 3/15 loss: 0.035116
Saved best CAE -> saved_spatial_caes_simple/cae_leather.pth (loss 0.035116)


leather Epoch 4/15:   0%|          | 0/31 [00:00<?, ?it/s]

leather Epoch 4/15 loss: 0.031667
Saved best CAE -> saved_spatial_caes_simple/cae_leather.pth (loss 0.031667)


leather Epoch 5/15:   0%|          | 0/31 [00:00<?, ?it/s]

leather Epoch 5/15 loss: 0.029600
Saved best CAE -> saved_spatial_caes_simple/cae_leather.pth (loss 0.029600)


leather Epoch 6/15:   0%|          | 0/31 [00:00<?, ?it/s]

leather Epoch 6/15 loss: 0.026527
Saved best CAE -> saved_spatial_caes_simple/cae_leather.pth (loss 0.026527)


leather Epoch 7/15:   0%|          | 0/31 [00:00<?, ?it/s]

leather Epoch 7/15 loss: 0.025192
Saved best CAE -> saved_spatial_caes_simple/cae_leather.pth (loss 0.025192)


leather Epoch 8/15:   0%|          | 0/31 [00:00<?, ?it/s]

leather Epoch 8/15 loss: 0.023983
Saved best CAE -> saved_spatial_caes_simple/cae_leather.pth (loss 0.023983)


leather Epoch 9/15:   0%|          | 0/31 [00:00<?, ?it/s]

leather Epoch 9/15 loss: 0.022479
Saved best CAE -> saved_spatial_caes_simple/cae_leather.pth (loss 0.022479)


leather Epoch 10/15:   0%|          | 0/31 [00:00<?, ?it/s]

leather Epoch 10/15 loss: 0.020632
Saved best CAE -> saved_spatial_caes_simple/cae_leather.pth (loss 0.020632)


leather Epoch 11/15:   0%|          | 0/31 [00:00<?, ?it/s]

leather Epoch 11/15 loss: 0.020050
Saved best CAE -> saved_spatial_caes_simple/cae_leather.pth (loss 0.020050)


leather Epoch 12/15:   0%|          | 0/31 [00:00<?, ?it/s]

leather Epoch 12/15 loss: 0.018960
Saved best CAE -> saved_spatial_caes_simple/cae_leather.pth (loss 0.018960)


leather Epoch 13/15:   0%|          | 0/31 [00:00<?, ?it/s]

leather Epoch 13/15 loss: 0.018216
Saved best CAE -> saved_spatial_caes_simple/cae_leather.pth (loss 0.018216)


leather Epoch 14/15:   0%|          | 0/31 [00:00<?, ?it/s]

leather Epoch 14/15 loss: 0.017885
Saved best CAE -> saved_spatial_caes_simple/cae_leather.pth (loss 0.017885)


leather Epoch 15/15:   0%|          | 0/31 [00:00<?, ?it/s]

leather Epoch 15/15 loss: 0.017154
Saved best CAE -> saved_spatial_caes_simple/cae_leather.pth (loss 0.017154)
=== Train CAE for metal_nut ===


metal_nut Epoch 1/15:   0%|          | 0/28 [00:00<?, ?it/s]

metal_nut Epoch 1/15 loss: 0.279422
Saved best CAE -> saved_spatial_caes_simple/cae_metal_nut.pth (loss 0.279422)


metal_nut Epoch 2/15:   0%|          | 0/28 [00:00<?, ?it/s]

metal_nut Epoch 2/15 loss: 0.097280
Saved best CAE -> saved_spatial_caes_simple/cae_metal_nut.pth (loss 0.097280)


metal_nut Epoch 3/15:   0%|          | 0/28 [00:00<?, ?it/s]

metal_nut Epoch 3/15 loss: 0.075574
Saved best CAE -> saved_spatial_caes_simple/cae_metal_nut.pth (loss 0.075574)


metal_nut Epoch 4/15:   0%|          | 0/28 [00:00<?, ?it/s]

metal_nut Epoch 4/15 loss: 0.071639
Saved best CAE -> saved_spatial_caes_simple/cae_metal_nut.pth (loss 0.071639)


metal_nut Epoch 5/15:   0%|          | 0/28 [00:00<?, ?it/s]

metal_nut Epoch 5/15 loss: 0.058399
Saved best CAE -> saved_spatial_caes_simple/cae_metal_nut.pth (loss 0.058399)


metal_nut Epoch 6/15:   0%|          | 0/28 [00:00<?, ?it/s]

metal_nut Epoch 6/15 loss: 0.053446
Saved best CAE -> saved_spatial_caes_simple/cae_metal_nut.pth (loss 0.053446)


metal_nut Epoch 7/15:   0%|          | 0/28 [00:00<?, ?it/s]

metal_nut Epoch 7/15 loss: 0.050257
Saved best CAE -> saved_spatial_caes_simple/cae_metal_nut.pth (loss 0.050257)


metal_nut Epoch 8/15:   0%|          | 0/28 [00:00<?, ?it/s]

metal_nut Epoch 8/15 loss: 0.052768


metal_nut Epoch 9/15:   0%|          | 0/28 [00:00<?, ?it/s]

metal_nut Epoch 9/15 loss: 0.049355
Saved best CAE -> saved_spatial_caes_simple/cae_metal_nut.pth (loss 0.049355)


metal_nut Epoch 10/15:   0%|          | 0/28 [00:00<?, ?it/s]

metal_nut Epoch 10/15 loss: 0.044457
Saved best CAE -> saved_spatial_caes_simple/cae_metal_nut.pth (loss 0.044457)


metal_nut Epoch 11/15:   0%|          | 0/28 [00:00<?, ?it/s]

metal_nut Epoch 11/15 loss: 0.042775
Saved best CAE -> saved_spatial_caes_simple/cae_metal_nut.pth (loss 0.042775)


metal_nut Epoch 12/15:   0%|          | 0/28 [00:00<?, ?it/s]

metal_nut Epoch 12/15 loss: 0.041470
Saved best CAE -> saved_spatial_caes_simple/cae_metal_nut.pth (loss 0.041470)


metal_nut Epoch 13/15:   0%|          | 0/28 [00:00<?, ?it/s]

metal_nut Epoch 13/15 loss: 0.040310
Saved best CAE -> saved_spatial_caes_simple/cae_metal_nut.pth (loss 0.040310)


metal_nut Epoch 14/15:   0%|          | 0/28 [00:00<?, ?it/s]

metal_nut Epoch 14/15 loss: 0.039027
Saved best CAE -> saved_spatial_caes_simple/cae_metal_nut.pth (loss 0.039027)


metal_nut Epoch 15/15:   0%|          | 0/28 [00:00<?, ?it/s]

metal_nut Epoch 15/15 loss: 0.039059
=== Train CAE for pill ===


pill Epoch 1/15:   0%|          | 0/34 [00:00<?, ?it/s]

pill Epoch 1/15 loss: 0.196613
Saved best CAE -> saved_spatial_caes_simple/cae_pill.pth (loss 0.196613)


pill Epoch 2/15:   0%|          | 0/34 [00:00<?, ?it/s]

pill Epoch 2/15 loss: 0.056765
Saved best CAE -> saved_spatial_caes_simple/cae_pill.pth (loss 0.056765)


pill Epoch 3/15:   0%|          | 0/34 [00:00<?, ?it/s]

pill Epoch 3/15 loss: 0.045601
Saved best CAE -> saved_spatial_caes_simple/cae_pill.pth (loss 0.045601)


pill Epoch 4/15:   0%|          | 0/34 [00:00<?, ?it/s]

pill Epoch 4/15 loss: 0.040787
Saved best CAE -> saved_spatial_caes_simple/cae_pill.pth (loss 0.040787)


pill Epoch 5/15:   0%|          | 0/34 [00:00<?, ?it/s]

pill Epoch 5/15 loss: 0.037447
Saved best CAE -> saved_spatial_caes_simple/cae_pill.pth (loss 0.037447)


pill Epoch 6/15:   0%|          | 0/34 [00:00<?, ?it/s]

pill Epoch 6/15 loss: 0.035005
Saved best CAE -> saved_spatial_caes_simple/cae_pill.pth (loss 0.035005)


pill Epoch 7/15:   0%|          | 0/34 [00:00<?, ?it/s]

pill Epoch 7/15 loss: 0.033014
Saved best CAE -> saved_spatial_caes_simple/cae_pill.pth (loss 0.033014)


pill Epoch 8/15:   0%|          | 0/34 [00:00<?, ?it/s]

pill Epoch 8/15 loss: 0.031690
Saved best CAE -> saved_spatial_caes_simple/cae_pill.pth (loss 0.031690)


pill Epoch 9/15:   0%|          | 0/34 [00:00<?, ?it/s]

pill Epoch 9/15 loss: 0.030444
Saved best CAE -> saved_spatial_caes_simple/cae_pill.pth (loss 0.030444)


pill Epoch 10/15:   0%|          | 0/34 [00:00<?, ?it/s]

pill Epoch 10/15 loss: 0.029131
Saved best CAE -> saved_spatial_caes_simple/cae_pill.pth (loss 0.029131)


pill Epoch 11/15:   0%|          | 0/34 [00:00<?, ?it/s]

pill Epoch 11/15 loss: 0.028260
Saved best CAE -> saved_spatial_caes_simple/cae_pill.pth (loss 0.028260)


pill Epoch 12/15:   0%|          | 0/34 [00:00<?, ?it/s]

pill Epoch 12/15 loss: 0.027547
Saved best CAE -> saved_spatial_caes_simple/cae_pill.pth (loss 0.027547)


pill Epoch 13/15:   0%|          | 0/34 [00:00<?, ?it/s]

pill Epoch 13/15 loss: 0.026754
Saved best CAE -> saved_spatial_caes_simple/cae_pill.pth (loss 0.026754)


pill Epoch 14/15:   0%|          | 0/34 [00:00<?, ?it/s]

pill Epoch 14/15 loss: 0.026258
Saved best CAE -> saved_spatial_caes_simple/cae_pill.pth (loss 0.026258)


pill Epoch 15/15:   0%|          | 0/34 [00:00<?, ?it/s]

pill Epoch 15/15 loss: 0.025614
Saved best CAE -> saved_spatial_caes_simple/cae_pill.pth (loss 0.025614)
=== Train CAE for screw ===


screw Epoch 1/15:   0%|          | 0/40 [00:00<?, ?it/s]

screw Epoch 1/15 loss: 0.187068
Saved best CAE -> saved_spatial_caes_simple/cae_screw.pth (loss 0.187068)


screw Epoch 2/15:   0%|          | 0/40 [00:00<?, ?it/s]

screw Epoch 2/15 loss: 0.069307
Saved best CAE -> saved_spatial_caes_simple/cae_screw.pth (loss 0.069307)


screw Epoch 3/15:   0%|          | 0/40 [00:00<?, ?it/s]

screw Epoch 3/15 loss: 0.055570
Saved best CAE -> saved_spatial_caes_simple/cae_screw.pth (loss 0.055570)


screw Epoch 4/15:   0%|          | 0/40 [00:00<?, ?it/s]

screw Epoch 4/15 loss: 0.049127
Saved best CAE -> saved_spatial_caes_simple/cae_screw.pth (loss 0.049127)


screw Epoch 5/15:   0%|          | 0/40 [00:00<?, ?it/s]

screw Epoch 5/15 loss: 0.043796
Saved best CAE -> saved_spatial_caes_simple/cae_screw.pth (loss 0.043796)


screw Epoch 6/15:   0%|          | 0/40 [00:00<?, ?it/s]

screw Epoch 6/15 loss: 0.040067
Saved best CAE -> saved_spatial_caes_simple/cae_screw.pth (loss 0.040067)


screw Epoch 7/15:   0%|          | 0/40 [00:00<?, ?it/s]

screw Epoch 7/15 loss: 0.037644
Saved best CAE -> saved_spatial_caes_simple/cae_screw.pth (loss 0.037644)


screw Epoch 8/15:   0%|          | 0/40 [00:00<?, ?it/s]

screw Epoch 8/15 loss: 0.036343
Saved best CAE -> saved_spatial_caes_simple/cae_screw.pth (loss 0.036343)


screw Epoch 9/15:   0%|          | 0/40 [00:00<?, ?it/s]

screw Epoch 9/15 loss: 0.034819
Saved best CAE -> saved_spatial_caes_simple/cae_screw.pth (loss 0.034819)


screw Epoch 10/15:   0%|          | 0/40 [00:00<?, ?it/s]

screw Epoch 10/15 loss: 0.033424
Saved best CAE -> saved_spatial_caes_simple/cae_screw.pth (loss 0.033424)


screw Epoch 11/15:   0%|          | 0/40 [00:00<?, ?it/s]

screw Epoch 11/15 loss: 0.031899
Saved best CAE -> saved_spatial_caes_simple/cae_screw.pth (loss 0.031899)


screw Epoch 12/15:   0%|          | 0/40 [00:00<?, ?it/s]

screw Epoch 12/15 loss: 0.030952
Saved best CAE -> saved_spatial_caes_simple/cae_screw.pth (loss 0.030952)


screw Epoch 13/15:   0%|          | 0/40 [00:00<?, ?it/s]

screw Epoch 13/15 loss: 0.030160
Saved best CAE -> saved_spatial_caes_simple/cae_screw.pth (loss 0.030160)


screw Epoch 14/15:   0%|          | 0/40 [00:00<?, ?it/s]

screw Epoch 14/15 loss: 0.028916
Saved best CAE -> saved_spatial_caes_simple/cae_screw.pth (loss 0.028916)


screw Epoch 15/15:   0%|          | 0/40 [00:00<?, ?it/s]

screw Epoch 15/15 loss: 0.028475
Saved best CAE -> saved_spatial_caes_simple/cae_screw.pth (loss 0.028475)
=== Train CAE for tile ===


tile Epoch 1/15:   0%|          | 0/29 [00:00<?, ?it/s]

tile Epoch 1/15 loss: 0.227025
Saved best CAE -> saved_spatial_caes_simple/cae_tile.pth (loss 0.227025)


tile Epoch 2/15:   0%|          | 0/29 [00:00<?, ?it/s]

tile Epoch 2/15 loss: 0.056627
Saved best CAE -> saved_spatial_caes_simple/cae_tile.pth (loss 0.056627)


tile Epoch 3/15:   0%|          | 0/29 [00:00<?, ?it/s]

tile Epoch 3/15 loss: 0.048743
Saved best CAE -> saved_spatial_caes_simple/cae_tile.pth (loss 0.048743)


tile Epoch 4/15:   0%|          | 0/29 [00:00<?, ?it/s]

tile Epoch 4/15 loss: 0.044964
Saved best CAE -> saved_spatial_caes_simple/cae_tile.pth (loss 0.044964)


tile Epoch 5/15:   0%|          | 0/29 [00:00<?, ?it/s]

tile Epoch 5/15 loss: 0.041778
Saved best CAE -> saved_spatial_caes_simple/cae_tile.pth (loss 0.041778)


tile Epoch 6/15:   0%|          | 0/29 [00:00<?, ?it/s]

tile Epoch 6/15 loss: 0.038572
Saved best CAE -> saved_spatial_caes_simple/cae_tile.pth (loss 0.038572)


tile Epoch 7/15:   0%|          | 0/29 [00:00<?, ?it/s]

tile Epoch 7/15 loss: 0.036226
Saved best CAE -> saved_spatial_caes_simple/cae_tile.pth (loss 0.036226)


tile Epoch 8/15:   0%|          | 0/29 [00:00<?, ?it/s]

tile Epoch 8/15 loss: 0.034209
Saved best CAE -> saved_spatial_caes_simple/cae_tile.pth (loss 0.034209)


tile Epoch 9/15:   0%|          | 0/29 [00:00<?, ?it/s]

tile Epoch 9/15 loss: 0.032620
Saved best CAE -> saved_spatial_caes_simple/cae_tile.pth (loss 0.032620)


tile Epoch 10/15:   0%|          | 0/29 [00:00<?, ?it/s]

tile Epoch 10/15 loss: 0.030891
Saved best CAE -> saved_spatial_caes_simple/cae_tile.pth (loss 0.030891)


tile Epoch 11/15:   0%|          | 0/29 [00:00<?, ?it/s]

tile Epoch 11/15 loss: 0.029872
Saved best CAE -> saved_spatial_caes_simple/cae_tile.pth (loss 0.029872)


tile Epoch 12/15:   0%|          | 0/29 [00:00<?, ?it/s]

tile Epoch 12/15 loss: 0.028538
Saved best CAE -> saved_spatial_caes_simple/cae_tile.pth (loss 0.028538)


tile Epoch 13/15:   0%|          | 0/29 [00:00<?, ?it/s]

tile Epoch 13/15 loss: 0.027938
Saved best CAE -> saved_spatial_caes_simple/cae_tile.pth (loss 0.027938)


tile Epoch 14/15:   0%|          | 0/29 [00:00<?, ?it/s]

tile Epoch 14/15 loss: 0.027173
Saved best CAE -> saved_spatial_caes_simple/cae_tile.pth (loss 0.027173)


tile Epoch 15/15:   0%|          | 0/29 [00:00<?, ?it/s]

tile Epoch 15/15 loss: 0.026029
Saved best CAE -> saved_spatial_caes_simple/cae_tile.pth (loss 0.026029)
=== Train CAE for toothbrush ===


toothbrush Epoch 1/15:   0%|          | 0/8 [00:00<?, ?it/s]

toothbrush Epoch 1/15 loss: 0.498087
Saved best CAE -> saved_spatial_caes_simple/cae_toothbrush.pth (loss 0.498087)


toothbrush Epoch 2/15:   0%|          | 0/8 [00:00<?, ?it/s]

toothbrush Epoch 2/15 loss: 0.163153
Saved best CAE -> saved_spatial_caes_simple/cae_toothbrush.pth (loss 0.163153)


toothbrush Epoch 3/15:   0%|          | 0/8 [00:00<?, ?it/s]

toothbrush Epoch 3/15 loss: 0.102591
Saved best CAE -> saved_spatial_caes_simple/cae_toothbrush.pth (loss 0.102591)


toothbrush Epoch 4/15:   0%|          | 0/8 [00:00<?, ?it/s]

toothbrush Epoch 4/15 loss: 0.078584
Saved best CAE -> saved_spatial_caes_simple/cae_toothbrush.pth (loss 0.078584)


toothbrush Epoch 5/15:   0%|          | 0/8 [00:00<?, ?it/s]

toothbrush Epoch 5/15 loss: 0.067156
Saved best CAE -> saved_spatial_caes_simple/cae_toothbrush.pth (loss 0.067156)


toothbrush Epoch 6/15:   0%|          | 0/8 [00:00<?, ?it/s]

toothbrush Epoch 6/15 loss: 0.059104
Saved best CAE -> saved_spatial_caes_simple/cae_toothbrush.pth (loss 0.059104)


toothbrush Epoch 7/15:   0%|          | 0/8 [00:00<?, ?it/s]

toothbrush Epoch 7/15 loss: 0.054193
Saved best CAE -> saved_spatial_caes_simple/cae_toothbrush.pth (loss 0.054193)


toothbrush Epoch 8/15:   0%|          | 0/8 [00:00<?, ?it/s]

toothbrush Epoch 8/15 loss: 0.050230
Saved best CAE -> saved_spatial_caes_simple/cae_toothbrush.pth (loss 0.050230)


toothbrush Epoch 9/15:   0%|          | 0/8 [00:00<?, ?it/s]

toothbrush Epoch 9/15 loss: 0.047567
Saved best CAE -> saved_spatial_caes_simple/cae_toothbrush.pth (loss 0.047567)


toothbrush Epoch 10/15:   0%|          | 0/8 [00:00<?, ?it/s]

toothbrush Epoch 10/15 loss: 0.046225
Saved best CAE -> saved_spatial_caes_simple/cae_toothbrush.pth (loss 0.046225)


toothbrush Epoch 11/15:   0%|          | 0/8 [00:00<?, ?it/s]

toothbrush Epoch 11/15 loss: 0.043814
Saved best CAE -> saved_spatial_caes_simple/cae_toothbrush.pth (loss 0.043814)


toothbrush Epoch 12/15:   0%|          | 0/8 [00:00<?, ?it/s]

toothbrush Epoch 12/15 loss: 0.042567
Saved best CAE -> saved_spatial_caes_simple/cae_toothbrush.pth (loss 0.042567)


toothbrush Epoch 13/15:   0%|          | 0/8 [00:00<?, ?it/s]

toothbrush Epoch 13/15 loss: 0.040949
Saved best CAE -> saved_spatial_caes_simple/cae_toothbrush.pth (loss 0.040949)


toothbrush Epoch 14/15:   0%|          | 0/8 [00:00<?, ?it/s]

toothbrush Epoch 14/15 loss: 0.040785
Saved best CAE -> saved_spatial_caes_simple/cae_toothbrush.pth (loss 0.040785)


toothbrush Epoch 15/15:   0%|          | 0/8 [00:00<?, ?it/s]

toothbrush Epoch 15/15 loss: 0.039837
Saved best CAE -> saved_spatial_caes_simple/cae_toothbrush.pth (loss 0.039837)
=== Train CAE for transistor ===


transistor Epoch 1/15:   0%|          | 0/27 [00:00<?, ?it/s]

transistor Epoch 1/15 loss: 0.270341
Saved best CAE -> saved_spatial_caes_simple/cae_transistor.pth (loss 0.270341)


transistor Epoch 2/15:   0%|          | 0/27 [00:00<?, ?it/s]

transistor Epoch 2/15 loss: 0.083721
Saved best CAE -> saved_spatial_caes_simple/cae_transistor.pth (loss 0.083721)


transistor Epoch 3/15:   0%|          | 0/27 [00:00<?, ?it/s]

transistor Epoch 3/15 loss: 0.068064
Saved best CAE -> saved_spatial_caes_simple/cae_transistor.pth (loss 0.068064)


transistor Epoch 4/15:   0%|          | 0/27 [00:00<?, ?it/s]

transistor Epoch 4/15 loss: 0.061469
Saved best CAE -> saved_spatial_caes_simple/cae_transistor.pth (loss 0.061469)


transistor Epoch 5/15:   0%|          | 0/27 [00:00<?, ?it/s]

transistor Epoch 5/15 loss: 0.056462
Saved best CAE -> saved_spatial_caes_simple/cae_transistor.pth (loss 0.056462)


transistor Epoch 6/15:   0%|          | 0/27 [00:00<?, ?it/s]

transistor Epoch 6/15 loss: 0.053153
Saved best CAE -> saved_spatial_caes_simple/cae_transistor.pth (loss 0.053153)


transistor Epoch 7/15:   0%|          | 0/27 [00:00<?, ?it/s]

transistor Epoch 7/15 loss: 0.050139
Saved best CAE -> saved_spatial_caes_simple/cae_transistor.pth (loss 0.050139)


transistor Epoch 8/15:   0%|          | 0/27 [00:00<?, ?it/s]

transistor Epoch 8/15 loss: 0.049371
Saved best CAE -> saved_spatial_caes_simple/cae_transistor.pth (loss 0.049371)


transistor Epoch 9/15:   0%|          | 0/27 [00:00<?, ?it/s]

transistor Epoch 9/15 loss: 0.046101
Saved best CAE -> saved_spatial_caes_simple/cae_transistor.pth (loss 0.046101)


transistor Epoch 10/15:   0%|          | 0/27 [00:00<?, ?it/s]

transistor Epoch 10/15 loss: 0.044277
Saved best CAE -> saved_spatial_caes_simple/cae_transistor.pth (loss 0.044277)


transistor Epoch 11/15:   0%|          | 0/27 [00:00<?, ?it/s]

transistor Epoch 11/15 loss: 0.043035
Saved best CAE -> saved_spatial_caes_simple/cae_transistor.pth (loss 0.043035)


transistor Epoch 12/15:   0%|          | 0/27 [00:00<?, ?it/s]

transistor Epoch 12/15 loss: 0.041618
Saved best CAE -> saved_spatial_caes_simple/cae_transistor.pth (loss 0.041618)


transistor Epoch 13/15:   0%|          | 0/27 [00:00<?, ?it/s]

transistor Epoch 13/15 loss: 0.040356
Saved best CAE -> saved_spatial_caes_simple/cae_transistor.pth (loss 0.040356)


transistor Epoch 14/15:   0%|          | 0/27 [00:00<?, ?it/s]

transistor Epoch 14/15 loss: 0.039586
Saved best CAE -> saved_spatial_caes_simple/cae_transistor.pth (loss 0.039586)


transistor Epoch 15/15:   0%|          | 0/27 [00:00<?, ?it/s]

transistor Epoch 15/15 loss: 0.038565
Saved best CAE -> saved_spatial_caes_simple/cae_transistor.pth (loss 0.038565)
=== Train CAE for wood ===


wood Epoch 1/15:   0%|          | 0/31 [00:00<?, ?it/s]

wood Epoch 1/15 loss: 0.186548
Saved best CAE -> saved_spatial_caes_simple/cae_wood.pth (loss 0.186548)


wood Epoch 2/15:   0%|          | 0/31 [00:00<?, ?it/s]

wood Epoch 2/15 loss: 0.057098
Saved best CAE -> saved_spatial_caes_simple/cae_wood.pth (loss 0.057098)


wood Epoch 3/15:   0%|          | 0/31 [00:00<?, ?it/s]

wood Epoch 3/15 loss: 0.049458
Saved best CAE -> saved_spatial_caes_simple/cae_wood.pth (loss 0.049458)


wood Epoch 4/15:   0%|          | 0/31 [00:00<?, ?it/s]

wood Epoch 4/15 loss: 0.044459
Saved best CAE -> saved_spatial_caes_simple/cae_wood.pth (loss 0.044459)


wood Epoch 5/15:   0%|          | 0/31 [00:00<?, ?it/s]

wood Epoch 5/15 loss: 0.041003
Saved best CAE -> saved_spatial_caes_simple/cae_wood.pth (loss 0.041003)


wood Epoch 6/15:   0%|          | 0/31 [00:00<?, ?it/s]

wood Epoch 6/15 loss: 0.038387
Saved best CAE -> saved_spatial_caes_simple/cae_wood.pth (loss 0.038387)


wood Epoch 7/15:   0%|          | 0/31 [00:00<?, ?it/s]

wood Epoch 7/15 loss: 0.035735
Saved best CAE -> saved_spatial_caes_simple/cae_wood.pth (loss 0.035735)


wood Epoch 8/15:   0%|          | 0/31 [00:00<?, ?it/s]

wood Epoch 8/15 loss: 0.033226
Saved best CAE -> saved_spatial_caes_simple/cae_wood.pth (loss 0.033226)


wood Epoch 9/15:   0%|          | 0/31 [00:00<?, ?it/s]

wood Epoch 9/15 loss: 0.033106
Saved best CAE -> saved_spatial_caes_simple/cae_wood.pth (loss 0.033106)


wood Epoch 10/15:   0%|          | 0/31 [00:00<?, ?it/s]

wood Epoch 10/15 loss: 0.031191
Saved best CAE -> saved_spatial_caes_simple/cae_wood.pth (loss 0.031191)


wood Epoch 11/15:   0%|          | 0/31 [00:00<?, ?it/s]

wood Epoch 11/15 loss: 0.029710
Saved best CAE -> saved_spatial_caes_simple/cae_wood.pth (loss 0.029710)


wood Epoch 12/15:   0%|          | 0/31 [00:00<?, ?it/s]

wood Epoch 12/15 loss: 0.029303
Saved best CAE -> saved_spatial_caes_simple/cae_wood.pth (loss 0.029303)


wood Epoch 13/15:   0%|          | 0/31 [00:00<?, ?it/s]

wood Epoch 13/15 loss: 0.028333
Saved best CAE -> saved_spatial_caes_simple/cae_wood.pth (loss 0.028333)


wood Epoch 14/15:   0%|          | 0/31 [00:00<?, ?it/s]

wood Epoch 14/15 loss: 0.027478
Saved best CAE -> saved_spatial_caes_simple/cae_wood.pth (loss 0.027478)


wood Epoch 15/15:   0%|          | 0/31 [00:00<?, ?it/s]

wood Epoch 15/15 loss: 0.026952
Saved best CAE -> saved_spatial_caes_simple/cae_wood.pth (loss 0.026952)
=== Train CAE for zipper ===


zipper Epoch 1/15:   0%|          | 0/30 [00:00<?, ?it/s]

zipper Epoch 1/15 loss: 0.202320
Saved best CAE -> saved_spatial_caes_simple/cae_zipper.pth (loss 0.202320)


zipper Epoch 2/15:   0%|          | 0/30 [00:00<?, ?it/s]

zipper Epoch 2/15 loss: 0.046141
Saved best CAE -> saved_spatial_caes_simple/cae_zipper.pth (loss 0.046141)


zipper Epoch 3/15:   0%|          | 0/30 [00:00<?, ?it/s]

zipper Epoch 3/15 loss: 0.036344
Saved best CAE -> saved_spatial_caes_simple/cae_zipper.pth (loss 0.036344)


zipper Epoch 4/15:   0%|          | 0/30 [00:00<?, ?it/s]

zipper Epoch 4/15 loss: 0.033378
Saved best CAE -> saved_spatial_caes_simple/cae_zipper.pth (loss 0.033378)


zipper Epoch 5/15:   0%|          | 0/30 [00:00<?, ?it/s]

zipper Epoch 5/15 loss: 0.031214
Saved best CAE -> saved_spatial_caes_simple/cae_zipper.pth (loss 0.031214)


zipper Epoch 6/15:   0%|          | 0/30 [00:00<?, ?it/s]

zipper Epoch 6/15 loss: 0.029098
Saved best CAE -> saved_spatial_caes_simple/cae_zipper.pth (loss 0.029098)


zipper Epoch 7/15:   0%|          | 0/30 [00:00<?, ?it/s]

zipper Epoch 7/15 loss: 0.027125
Saved best CAE -> saved_spatial_caes_simple/cae_zipper.pth (loss 0.027125)


zipper Epoch 8/15:   0%|          | 0/30 [00:00<?, ?it/s]

zipper Epoch 8/15 loss: 0.025438
Saved best CAE -> saved_spatial_caes_simple/cae_zipper.pth (loss 0.025438)


zipper Epoch 9/15:   0%|          | 0/30 [00:00<?, ?it/s]

zipper Epoch 9/15 loss: 0.023695
Saved best CAE -> saved_spatial_caes_simple/cae_zipper.pth (loss 0.023695)


zipper Epoch 10/15:   0%|          | 0/30 [00:00<?, ?it/s]

zipper Epoch 10/15 loss: 0.022500
Saved best CAE -> saved_spatial_caes_simple/cae_zipper.pth (loss 0.022500)


zipper Epoch 11/15:   0%|          | 0/30 [00:00<?, ?it/s]

zipper Epoch 11/15 loss: 0.021653
Saved best CAE -> saved_spatial_caes_simple/cae_zipper.pth (loss 0.021653)


zipper Epoch 12/15:   0%|          | 0/30 [00:00<?, ?it/s]

zipper Epoch 12/15 loss: 0.020782
Saved best CAE -> saved_spatial_caes_simple/cae_zipper.pth (loss 0.020782)


zipper Epoch 13/15:   0%|          | 0/30 [00:00<?, ?it/s]

zipper Epoch 13/15 loss: 0.020063
Saved best CAE -> saved_spatial_caes_simple/cae_zipper.pth (loss 0.020063)


zipper Epoch 14/15:   0%|          | 0/30 [00:00<?, ?it/s]

zipper Epoch 14/15 loss: 0.019352
Saved best CAE -> saved_spatial_caes_simple/cae_zipper.pth (loss 0.019352)


zipper Epoch 15/15:   0%|          | 0/30 [00:00<?, ?it/s]

zipper Epoch 15/15 loss: 0.018721
Saved best CAE -> saved_spatial_caes_simple/cae_zipper.pth (loss 0.018721)


In [36]:
# Sanity-check the feature extractor: its forward hooks should be registered on
# the last Bottleneck of layer2 and layer3 and should fire on a forward pass.

# 1) confirm extractor exists and has handles registered
extractor_exists = "extractor" in globals()
print("extractor exists:", extractor_exists)
# Stop with a clear message instead of a NameError on the next line.
assert extractor_exists, "`extractor` is not defined - run the feature-extractor cell first"
hook_handles = getattr(extractor, "handles", [])
print("handles attr present?", hasattr(extractor, "handles"))
print("num handles:", len(hook_handles))
print("handles objects:", hook_handles)

# 2) show the modules we expect hooks on
print("layer2[-1]:", extractor.model.layer2[-1])
print("layer3[-1]:", extractor.model.layer3[-1])

# 3) quick forward test (random input)
x = torch.randn(1, 3, 224, 224).to(device)
extractor.features = []    # clear anything captured by a previous call
with torch.no_grad():
    _ = extractor(x)
print("features captured:", len(extractor.features))
# Iterating an empty list prints nothing, so no explicit length guard is needed.
for i, f in enumerate(extractor.features):
    print(i, f.shape, f.device)


extractor exists: True
handles attr present? True
num handles: 2
handles objects: [<torch.utils.hooks.RemovableHandle object at 0x308243250>, <torch.utils.hooks.RemovableHandle object at 0x30857fc50>]
layer2[-1]: Bottleneck(
  (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
)
layer3[-1]: Bottleneck(
  (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(256, 256, kernel_size=(3,

In [37]:
# Reload the best per-category CAE checkpoints written by the training loop.
# in_channels=1536 is the concatenated extractor output channel count:
# layer2 blocks emit 512 channels and layer3 blocks 1024 (512 + 1024 = 1536).
trained_caes = {}
for cat in category_list:
    p = CAEMODEL_DIR / f"cae_{cat}.pth"
    if not p.exists():
        print("Missing CAE for", cat)
        continue
    cae = FeatCAE(in_channels=1536, latent_dim=LATENT_DIM, is_bn=True)
    # The checkpoint is passed straight to load_state_dict, so it is a plain
    # tensor state_dict; weights_only=True avoids unpickling arbitrary objects.
    state_dict = torch.load(p, map_location=device, weights_only=True)
    cae.load_state_dict(state_dict)
    cae.to(device).eval()
    trained_caes[cat] = cae
print("Loaded CAEs:", list(trained_caes.keys()))


Loaded CAEs: ['bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut', 'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush', 'transistor', 'wood', 'zipper']
