In [None]:
# === Autoencoder for Defect Detection ===
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt


In [None]:

# Simulate "defect" dataset (replace with real image dataset when available)
SEED = 42
N_NORMAL, N_DEFECT, N_FEATURES = 950, 50, 100

# Seed every RNG the notebook relies on: numpy generates the data below,
# torch drives the autoencoder's weight initialisation later on.
np.random.seed(SEED)
torch.manual_seed(SEED)

normal_imgs = np.random.normal(0.5, 0.1, (N_NORMAL, N_FEATURES))   # Normal images
defect_imgs = np.random.uniform(0, 1, (N_DEFECT, N_FEATURES))      # Defective (random noise)
all_imgs = np.vstack([normal_imgs, defect_imgs])                   # normal rows first, then defects
X = torch.tensor(all_imgs, dtype=torch.float32)

# Ground-truth labels, following the stacking order above; used to evaluate the detector.
is_defect = np.arange(len(all_imgs)) >= N_NORMAL

assert X.shape == (N_NORMAL + N_DEFECT, N_FEATURES)
print("Dataset shape:", X.shape) # (total number of images, features per image)

In [None]:

# Define simple Autoencoder
class AutoEncoder(nn.Module):
    """Fully connected autoencoder.

    Architecture: input_dim -> hidden_dim -> encoding_dim -> hidden_dim -> input_dim,
    with ReLU activations and a final Sigmoid. Because of the Sigmoid, reconstructions
    always lie in [0, 1]; targets outside that range can never be reconstructed exactly.

    Args:
        input_dim: number of features per sample (size of the input's last dimension).
        encoding_dim: size of the bottleneck code.
        hidden_dim: width of the hidden layer on each side of the bottleneck.
    """

    def __init__(self, input_dim=100, encoding_dim=10, hidden_dim=64):
        super().__init__()
        # ReLU after the bottleneck keeps the code non-negative.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, encoding_dim),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Linear(encoding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Encode then decode ``x``; the result has the same shape as ``x``."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded


In [None]:

# Initialize model.
# Seed torch right before construction so the weight initialisation (and therefore
# the loss curve and the detections that follow) is identical on every re-run of this cell.
torch.manual_seed(42)
model = AutoEncoder()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

N_EPOCHS = 100
LOG_EVERY = 10

# === Train only on normal images ===
# all_imgs stacks normal_imgs first, so the first len(normal_imgs) rows of X are the normal samples.
X_train = X[:len(normal_imgs)]

model.train()
# Full-batch training: each epoch is a single gradient step over all normal samples.
for epoch in range(N_EPOCHS):
    optimizer.zero_grad()
    output = model(X_train)
    loss = criterion(output, X_train)  # reconstruction loss: the input is its own target
    loss.backward()
    optimizer.step()
    if (epoch + 1) % LOG_EVERY == 0:
        print(f"Epoch [{epoch+1}/{N_EPOCHS}], Loss: {loss.item():.4f}")


In [None]:

# === Reconstruction and Anomaly Detection ===
model.eval()  # inference mode (no-op for this model, but correct if dropout/batchnorm are added)
with torch.no_grad():
    reconstructed = model(X)
    mse = torch.mean((X - reconstructed) ** 2, dim=1)  # one reconstruction error per sample
errors = mse.numpy()

# all_imgs stacks normal_imgs first: rows [0, n_normal) are normal, the rest are defects.
n_normal = len(normal_imgs)
is_defect = np.arange(len(errors)) >= n_normal

# Fit the threshold on the normal (training) samples' errors only. A percentile over ALL
# samples includes the defects themselves, so it flags ~5% of the data by construction,
# whatever the model learned. With this choice, about (100 - THRESHOLD_PERCENTILE)% of
# normal samples are expected to exceed the threshold (the false-alarm rate).
THRESHOLD_PERCENTILE = 95
threshold = np.percentile(errors[:n_normal], THRESHOLD_PERCENTILE)
anomalies = errors > threshold

# === Plot Reconstruction Errors ===
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(errors, label="Reconstruction Error")
ax.axhline(threshold, color='r', linestyle='--',
           label=f"Threshold ({THRESHOLD_PERCENTILE}th percentile of normal errors)")
ax.axvline(n_normal - 0.5, color='gray', linestyle=':', label="Start of simulated defects")
ax.set_title("Autoencoder: Image Anomaly Detection")
ax.set_xlabel("Sample Index")
ax.set_ylabel("MSE (Reconstruction Error)")
ax.legend()
ax.grid()
plt.show()

# Score the detector against the known simulation labels.
true_pos = int(np.sum(anomalies & is_defect))
false_pos = int(np.sum(anomalies & ~is_defect))
print(f"Detected anomalies: {np.sum(anomalies)} out of {len(X)} samples")
print(f"True defects caught: {true_pos}/{int(is_defect.sum())}, "
      f"false alarms on normal samples: {false_pos}/{n_normal}")