In [1]:
import torch
import torchvision

In [None]:
class auto_encoder(torch.nn.Module):
    def __init__(self, input_size, bottleneck):
        super(auto_encoder, self).__init__()
        layers = [input_size] + bottleneck
        encoder_layers = []
        decoder_layers = []
        
        for i in range(len(layers)-1):
            encoder_layers.append(torch.nn.Linear(layers[i], layers[i+1]))
            encoder_layers.append(torch.nn.ReLU())
            
        for i in range(len(layers)-1, 0, -1):
            decoder_layers.append(torch.nn.Linear(layers[i], layers[i-1]))
            if i > 1:  # No ReLU after last layer
                decoder_layers.append(torch.nn.ReLU())
                
        self.encoder = torch.nn.Sequential(*encoder_layers)
        self.decoder = torch.nn.Sequential(*decoder_layers)
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


dataset = torchvision.datasets.MNIST(root='./', train=True, download=True, transform=torchvision.transforms.ToTensor())    
full_dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
only_number_to_train = 4
dataset.data = dataset.data[dataset.targets==only_number_to_train]
dataset.targets = dataset.targets[dataset.targets==only_number_to_train]
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

from tqdm import tqdm 

bottlenecks = [128,64,32]
num_epochs=10
loss_fn = torch.nn.MSELoss()

for bottleneck in bottlenecks:
    model = auto_encoder(28*28, [bottleneck])
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    with tqdm(total=num_epochs*(len(train_dataloader))) as pbar:
        for epoch in range(num_epochs):
            for data in train_dataloader:
                optimizer.zero_grad()
                img, _ = data
                img = img.view(img.size(0), -1)
                xhat = model(img)
                loss = loss_fn(xhat, img)
                loss.backward()
                optimizer.step()
                pbar.update(1)
                pbar.set_description(f"Training Loss: {loss.item()}")
    torch.save(model.state_dict(), f"autoencoder_{bottleneck}.pt")

    losses = []
    numberwise_losses = [[] for _ in range(10)]

    with torch.no_grad():
        for data in full_dataloader:
            img, number = data
            img = img.view(img.size(0), -1)
            xhat = model(img)
            batch_loss = loss_fn(xhat, img).item()
            losses.append(batch_loss)
 
            for idx, n in enumerate(number):
                numberwise_losses[n.item()].append(batch_loss)

    mean_loss = torch.tensor(losses).mean().item()
    std_loss = torch.tensor(losses).std().item()

    digit_means = [torch.tensor(d_losses).mean().item() if d_losses else float('inf') 
                  for d_losses in numberwise_losses]
    ## basically choose the 99% confidence interval
    trained_digit_mean = digit_means[only_number_to_train]
    threshold = trained_digit_mean + 2 * torch.tensor(numberwise_losses[only_number_to_train]).std().item()
    
    print(f"\nBottleneck size: {bottleneck}")
    print(f"Overall mean loss: {mean_loss:.4f} ± {std_loss:.4f}")
    print(f"\nMean loss per digit:")
    for digit, mean in enumerate(digit_means):
        print(f"Digit {digit}: {mean:.4f} {'(trained)' if digit==only_number_to_train else ''}")
    
    print(f"\nDetection threshold: {threshold:.4f}")
    print("\nClassification results:")

    print("Running On Means")
    for digit, mean in enumerate(digit_means):
        detected = mean <= threshold
        print(f"Digit {digit}: {'Detected as trained class' if detected else 'Different class'}")
    
    print("Running On rest of the data")
    for image,digit in full_dataloader:
        img = image.view(image.size(0), -1)
        xhat = model(img)
        batch_loss = loss_fn(xhat, img).item()
        detected = batch_loss <= threshold
        print(f"Digit {digit.item()}: {'Detected as trained class' if detected else 'Different class'}")






  0%|          | 0/920 [00:00<?, ?it/s]

Training Loss: 0.007173612248152494: 100%|██████████| 920/920 [00:10<00:00, 85.59it/s] 



Bottleneck size: 128
Overall mean loss: 0.0071 ± 0.0003

Mean loss per digit:
Digit 0: inf 
Digit 1: inf 
Digit 2: inf 
Digit 3: inf 
Digit 4: 0.0071 (trained)
Digit 5: inf 
Digit 6: inf 
Digit 7: inf 
Digit 8: inf 
Digit 9: inf 

Detection threshold: 0.0078

Classification results:
Digit 0: Different class
Digit 1: Different class
Digit 2: Different class
Digit 3: Different class
Digit 4: Detected as trained class
Digit 5: Different class
Digit 6: Different class
Digit 7: Different class
Digit 8: Different class
Digit 9: Different class


Training Loss: 0.013123828917741776: 100%|██████████| 920/920 [00:11<00:00, 80.85it/s]



Bottleneck size: 64
Overall mean loss: 0.0143 ± 0.0007

Mean loss per digit:
Digit 0: inf 
Digit 1: inf 
Digit 2: inf 
Digit 3: inf 
Digit 4: 0.0143 (trained)
Digit 5: inf 
Digit 6: inf 
Digit 7: inf 
Digit 8: inf 
Digit 9: inf 

Detection threshold: 0.0158

Classification results:
Digit 0: Different class
Digit 1: Different class
Digit 2: Different class
Digit 3: Different class
Digit 4: Detected as trained class
Digit 5: Different class
Digit 6: Different class
Digit 7: Different class
Digit 8: Different class
Digit 9: Different class


Training Loss: 0.025942740961909294: 100%|██████████| 920/920 [00:11<00:00, 83.23it/s]



Bottleneck size: 32
Overall mean loss: 0.0244 ± 0.0011

Mean loss per digit:
Digit 0: inf 
Digit 1: inf 
Digit 2: inf 
Digit 3: inf 
Digit 4: 0.0244 (trained)
Digit 5: inf 
Digit 6: inf 
Digit 7: inf 
Digit 8: inf 
Digit 9: inf 

Detection threshold: 0.0265

Classification results:
Digit 0: Different class
Digit 1: Different class
Digit 2: Different class
Digit 3: Different class
Digit 4: Detected as trained class
Digit 5: Different class
Digit 6: Different class
Digit 7: Different class
Digit 8: Different class
Digit 9: Different class
