In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from sklearn.model_selection import train_test_split


from PIL import Image
import numpy as np

import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [8]:
# autoencoder code borrowed from patrickloeber: https://github.com/patrickloeber/pytorch-examples/blob/master/Autoencoder.ipynb

class Conv_AutoEncoder(nn.Module):
    def __init__(self):
        super().__init__() 
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 8, 4, stride = 3, padding = 1), # 32, 16, 67, 67
            nn.ReLU(),
            nn.Conv2d(8, 16, 3, stride = 2, padding = 1),# 32, 32, 34, 34
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride = 2, padding = 1), # 32, 32, 17, 17
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 17 * 17, 3000),
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(3000, 32 * 17 * 17),
            nn.ReLU(),
            nn.Unflatten(dim = 1, unflattened_size=(32, 17, 17)),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride = 2, padding = 1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 8, 3, stride = 2, padding = 1),
            nn.ReLU(),
            nn.ConvTranspose2d(8, 1, 4, stride = 3, padding = 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

In [5]:
class CustomDataset(Dataset):
    def __init__(self, root_dir, file_list, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.file_list = file_list

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        img_name = self.file_list[0]
        img_name = os.path.join(self.root_dir, img_name)
        image = Image.open(img_name)
        if self.transform:
            image = self.transform(image)
        return image

# Define transformations to apply to your images
transform = transforms.Compose([
    # transforms.Resize((200, 200)), # not necessary, since images are already 200x200
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])


In [6]:
# perform train test split
# iterate through all subfolders (rex, tri, etc)
current_dir = os.getcwd()
project_dir = os.path.dirname(current_dir)
images_dir = os.path.join(project_dir, 'augmented_images')

file_list = []

for folder in os.listdir(images_dir):
    if folder == 'mis':
        continue
    folder_fp = os.path.join(images_dir, folder) 
    for img in os.listdir(folder_fp):
        file_list.append(os.path.join(folder_fp, img))

train_indices, val_indices = train_test_split(range(len(file_list)), test_size=0.2, random_state=42)
train_files = [file_list[i] for i in train_indices]
validation_files = [file_list[i] for i in val_indices]

In [21]:
len(validation_files)

1901

In [24]:
current_dir = os.getcwd()
project_dir = os.path.dirname(current_dir)
images_dir = os.path.join(project_dir, 'augmented_images')

validation = CustomDataset(file_list = validation_files, root_dir = images_dir, transform=transform)

# Create a dataloader
batch_size = 32
validation_loader = DataLoader(validation, batch_size=batch_size, shuffle=True)

misformed_images_dir = os.path.join(project_dir, 'augmented_images')

mis_dir = os.path.join(misformed_images_dir, 'mis')
mis_files = []

for img in os.listdir(mis_dir):
    mis_files.append(os.path.join(mis_dir, img))


misformed = CustomDataset(file_list = mis_files, root_dir = images_dir, transform=transform)
misformed_loader = DataLoader(misformed, batch_size=batch_size, shuffle=False)

In [25]:
model = Conv_AutoEncoder()
model.load_state_dict(torch.load('C:\\Users\\So\\Documents\\code\\datahacks_2024\\classification\\validated_conv_autoencoder_50_epochs.pth'))
model.to(device)

Conv_AutoEncoder(
  (encoder): Sequential(
    (0): Conv2d(1, 8, kernel_size=(4, 4), stride=(3, 3), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): ReLU()
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=9248, out_features=3000, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=3000, out_features=9248, bias=True)
    (1): ReLU()
    (2): Unflatten(dim=1, unflattened_size=(32, 17, 17))
    (3): ReLU()
    (4): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (5): ReLU()
    (6): ConvTranspose2d(16, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): ReLU()
    (8): ConvTranspose2d(8, 1, kernel_size=(4, 4), stride=(3, 3), padding=(1, 1))
    (9): Sigmoid()
  )
)

In [35]:
test = {
    'og': [],
    'recon': []
}

anomaly_test = {
    'og': [],
    'recon': []
}

for img in validation_loader:
    img = img.to(device)
    # img = img.reshape(-1, 200*200)
    recon = model(img)
    test['og'].append(img)
    test['recon'].append(recon)


for img in misformed_loader:
    img = img.to(device)
    # img = img.reshape(-1, 200*200)
    recon = model(img)
    anomaly_test['og'].append(img)
    anomaly_test['recon'].append(recon)

In [36]:
normal_mse = []
for batch_n in range(len(test['og'])):
    for i in range(len(test['og'][batch_n])):
        og = test['og'][batch_n][i].squeeze()
        recon = test['recon'][batch_n][i].squeeze()
        normal_mse.append(F.mse_loss(og, recon).item())

disformed_mse = []
for batch_n in range(len(anomaly_test['og'])):
    for i in range(len(anomaly_test['og'][batch_n])):
        og = anomaly_test['og'][batch_n][i].squeeze()
        recon = anomaly_test['recon'][batch_n][i].squeeze()
        disformed_mse.append(F.mse_loss(og, recon).item())

normal_mse, disformed_mse = np.array(normal_mse), np.array(disformed_mse)

In [39]:
normal_mse.std(), normal_mse.mean()

(7.552036462675999e-12, 0.0028299184050410986)

In [None]:
disformed_mse.std(), disformed_mse.mean()