In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torchvision.datasets as datasets
import torchvision.utils as vutils
import torch.optim as optim
import numpy as np
from PIL import Image
import os
import random
import pandas as pd
from torchvision.transforms import RandomRotation, RandomHorizontalFlip, RandomVerticalFlip, RandomResizedCrop

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:

# Define Generator
class Generator(nn.Module):
    def __init__(self,nz,ngf,nc):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )
    
    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    def __init__(self, nc, ndf):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False)  # This may output [B, 1, 11, 11] for 224x224 inputs
        )
        
        # Adaptive pooling to force the spatial dimensions to 1x1
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        x = self.main(input)
        # x will be [B, 1, H, W] (e.g., [16, 1, 11, 11] for 224x224 images)
        x = self.pool(x)  # Now x is [B, 1, 1, 1]
        x = x.view(x.size(0), -1)  # Flatten to [B, 1]
        return self.sigmoid(x)





In [3]:

class TrainDataset(Dataset):
    def __init__(self, images_root, image_files, transform=None, augmentations=None):
        self.image_files = image_files
        self.images_root = images_root
        self.transform = transform
        self.augmentations = augmentations

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        image = Image.open(os.path.join(self.images_root, img_path)).convert('RGB')

        if self.augmentations:
            image = self.augmentations(image)
        
        if self.transform:
            image = self.transform(image)
            

        return image

    
class TestDataset(Dataset):
    def __init__(self, images_root, image_files, transform=None, augmentations=None):
        self.image_files = image_files
        self.images_root = images_root
        self.transform = transform
        self.augmentations = augmentations


    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        image = Image.open(os.path.join(self.images_root, img_path)).convert('RGB')
        
        if self.transform:
            image = self.transform(image)

        # Determine the label based on the filename or directory structure
        if 'worm' in self.image_files[idx].lower():  # assuming 'worm' indicates worm-cropped images
            label = 1  # worm present
        else:
            label = 0  # no worm
        
        
        return image, label


In [4]:
# Example vae_loss function (you can replace this with your own)
def vae_loss(reconstructed, original, mu, logvar):
    """
    Compute the VAE loss.
      - Reconstruction loss: binary cross-entropy between reconstructed and original image.
      - KL divergence loss: forcing the latent space distribution closer to N(0,1)
    Returns a tuple of (total_loss, recon_loss, kl_loss)
    """
    bce = nn.functional.binary_cross_entropy(reconstructed, original, reduction='mean')
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    total_loss = bce + kl
    return total_loss, bce, kl

def vae_loss_inference(reconstructed, original, mu, logvar):
    recon_loss = nn.functional.mse_loss(reconstructed, original, reduction="none")
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_div



In [5]:
# Hyperparameters
nz = 100  # Latent vector size
ngf = 224  # Generator feature map size
ndf = 224  # Discriminator feature map size
nc = 3    # Number of channels (1 for grayscale, 3 for RGB)
batch_size = 16
epochs = 150
lr = 0.0002
beta1 = 0.5  # Adam optimizer beta1
early_stop_patient = 150


# Data loading
data_root = "./k-fold/fold_1"
splits = ["train", "eval", "test"]

data_dict = {}
datasets_dict = {}


data_transforms = transforms.Compose([
    RandomRotation(degrees=30),  # Randomly rotate images by up to 30 degrees
    RandomHorizontalFlip(p=0.5),  # Randomly flip images horizontally with a probability of 0.5
    RandomVerticalFlip(p=0.5),  # Randomly flip images vertically with a probability of 0.5
    RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0))  # Randomly crop and resize images to 128x128 with a scale between 80% and 100% of the original size
])


# Load CSV files and create datasets
for split in splits:
    try:
        # Load CSV data
        csv_path =  f"{data_root}/{split}/{split}.csv"
        data_dict[split] = pd.read_csv(csv_path)
        # print(data_dict["train"]['filename'].values)
        # Create dataset for each split
        datasets_dict[split] = TestDataset(
            images_root="./data/",
            image_files=data_dict[split]['filename'].values,
            transform=transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor()
            ]),
            augmentations=data_transforms if split == "train" else None
        )
    except Exception as e:
        print(f"Error loading {split} split: {e}")

# Create dataloaders
dataloaders = {
    split: DataLoader(
        datasets_dict[split],
        batch_size=batch_size if split == "train" else 1,
        shuffle=(split == "train"),
        num_workers=4,
        pin_memory=True
    ) for split in splits
}

print("Available splits:", list(datasets_dict.keys()))
print("Dataset sizes:", {split: len(dataset) for split, dataset in datasets_dict.items()})


Available splits: ['train', 'eval', 'test']
Dataset sizes: {'train': 122, 'eval': 16, 'test': 35}


In [6]:

# Initialize models
generator = Generator(nz,ngf,nc).to(device)
discriminator = Discriminator(nc, ndf).to(device)

generator.apply(lambda m: nn.init.normal_(m.weight, 0, 0.02) if hasattr(m, 'weight') else None)
discriminator.apply(lambda m: nn.init.normal_(m.weight, 0, 0.02) if hasattr(m, 'weight') else None)

# Loss and optimizer
criterion = nn.BCELoss()
optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerD = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

# Training Loop
for epoch in range(epochs):
    for real_images,_ in dataloaders["train"]:
        real_images = real_images.to(device)
        batch_size = real_images.size(0)
        # print(real_images.shape)
        # Train Discriminator
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake_images = generator(noise)
        
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)
        
        optimizerD.zero_grad()
        output_real = discriminator(real_images)
        # print(output_real.shape)
        # print(real_labels.shape)
        output_fake = discriminator(fake_images.detach())
        # print(output_fake.shape)
        # print(fake_labels.shape)

        lossD = criterion(output_real, real_labels) + criterion(output_fake, fake_labels)
        lossD.backward()
        optimizerD.step()
        
        # Train Generator
        optimizerG.zero_grad()
        output_fake = discriminator(fake_images)
        lossG = criterion(output_fake, real_labels)
        lossG.backward()
        optimizerG.step()
        
    print(f"Epoch [{epoch+1}/{epochs}] Loss D: {lossD.item():.4f}, Loss G: {lossG.item():.4f}")
    
    # Save sample images
    if epoch % 10 == 0:
        vutils.save_image(fake_images, f"reconstruction_images/epoch_{epoch}.png", normalize=True)
        torch.save(generator.state_dict(), "./models/gen.pth")
        torch.save(discriminator.state_dict(), "./models/disc.pth")



print("Training Complete!")


Epoch [1/150] Loss D: 0.7335, Loss G: 2.1064
Epoch [2/150] Loss D: 0.3748, Loss G: 3.0677
Epoch [3/150] Loss D: 0.2093, Loss G: 3.4999
Epoch [4/150] Loss D: 0.1251, Loss G: 3.8794
Epoch [5/150] Loss D: 0.0823, Loss G: 4.1670
Epoch [6/150] Loss D: 0.0575, Loss G: 4.4120
Epoch [7/150] Loss D: 0.0411, Loss G: 4.7333
Epoch [8/150] Loss D: 0.0322, Loss G: 4.8823
Epoch [9/150] Loss D: 0.0255, Loss G: 5.0728
Epoch [10/150] Loss D: 0.0211, Loss G: 5.2391
Epoch [11/150] Loss D: 0.0168, Loss G: 5.4758
Epoch [12/150] Loss D: 0.0149, Loss G: 5.5459
Epoch [13/150] Loss D: 0.0125, Loss G: 5.7336
Epoch [14/150] Loss D: 0.0107, Loss G: 5.9032
Epoch [15/150] Loss D: 0.0095, Loss G: 6.0279
Epoch [16/150] Loss D: 0.0081, Loss G: 6.1361
Epoch [17/150] Loss D: 0.0072, Loss G: 6.2284
Epoch [18/150] Loss D: 0.0068, Loss G: 6.1869
Epoch [19/150] Loss D: 0.0059, Loss G: 6.3526
Epoch [20/150] Loss D: 0.0055, Loss G: 6.4088
Epoch [21/150] Loss D: 0.0049, Loss G: 6.5244
Epoch [22/150] Loss D: 0.0044, Loss G: 6.64

In [60]:
def detect_anomaly_threshold(gen, disc, test_data_loader, device):
    gen.eval()
    disc.eval()
    anomaly_scores = []

    with torch.no_grad():
        for real_images,_ in test_data_loader:  
            real_images = real_images.to(device)
            batch_size = real_images.size(0)

            # Generate fake images
            noise = torch.randn(batch_size, nz, 1, 1, device=device)  # Assuming `nz` is the latent space size
            fake_images = gen(noise)

            # Get discriminator predictions
            real_preds = disc(real_images)  # Should be close to 1 for real images
            fake_preds = disc(fake_images)  # Should be close to 0 for fake images
            
            # Compute anomaly score as the difference between real and fake scores
            anomaly_score = torch.abs(real_preds - fake_preds)  # The more different, the more anomalous
            anomaly_scores.extend(anomaly_score.cpu().numpy())  # Store anomaly scores

    # Convert scores to tensor for statistical calculations
    anomaly_scores = torch.tensor(anomaly_scores, device=device)

    # Compute threshold (mean + std/2)
    mean_score = anomaly_scores.mean().item()
    std_score = anomaly_scores.std().item()
    anomaly_threshold = mean_score + std_score


    # Log information
    print(f"Mean Anomaly Score: {mean_score}, Std: {std_score}")
    print(f"Anomaly Threshold: {anomaly_threshold}")

    return anomaly_threshold


In [65]:
def test_gan(gen, disc, test_data_loader, anomaly_threshold, device):
    gen.eval()  
    disc.eval()  

    anomalies = []
    true_labels = []  
    pred_labels = []

    with torch.no_grad():
        for real_images, labels in test_data_loader:
            real_images = real_images.to(device)
            true_labels.extend(labels.cpu().numpy())  

            batch_size = real_images.size(0)

            # Generate fake images
            noise = torch.randn(batch_size, nz, 1, 1, device=device)  # Assuming `nz` is latent size
            fake_images = gen(noise)

            # Discriminator scores
            real_scores = disc(real_images).squeeze()  # Get scores for real images
            fake_scores = disc(fake_images).squeeze()  # Get scores for generated images

            # Compute anomaly score as absolute difference
            anomaly_scores = torch.abs(real_scores - fake_scores)
            print(anomaly_scores)

            # Classify images based on anomaly threshold
            if anomaly_scores >anomaly_threshold:
                anomalies.append(("no tericho", anomaly_scores))  # Anomalous (not tericho)
                pred_labels.append(0)
            else:
                anomalies.append(("tericho",anomaly_scores))  # Normal (tericho)
                pred_labels.append(1)

    return real_images, fake_images, anomalies, true_labels, pred_labels  


In [62]:
th = detect_anomaly_threshold(gen= generator, disc = discriminator,test_data_loader=dataloaders['eval'],device=device)
print(th)

Mean Anomaly Score: 0.9522849917411804, Std: 0.0013087138067930937
Anomaly Threshold: 0.9535937055479735
0.9535937055479735


In [None]:
images, reconstructed ,anomalies, true_labels, pred_labels = test_gan(gen= generator, disc = discriminator , test_data_loader=dataloaders['test'] , anomaly_threshold=th ,device=device)

tensor(0.9527, device='cuda:0')
tensor(0.9556, device='cuda:0')
tensor(0.9549, device='cuda:0')
tensor(0.9529, device='cuda:0')
tensor(0.9564, device='cuda:0')
tensor(0.9568, device='cuda:0')
tensor(0.9565, device='cuda:0')
tensor(0.9563, device='cuda:0')
tensor(0.9554, device='cuda:0')
tensor(0.9555, device='cuda:0')
tensor(0.9529, device='cuda:0')
tensor(0.9529, device='cuda:0')
tensor(0.9534, device='cuda:0')
tensor(0.9510, device='cuda:0')
tensor(0.9539, device='cuda:0')
tensor(0.9518, device='cuda:0')
tensor(0.9539, device='cuda:0')
tensor(0.9531, device='cuda:0')
tensor(0.9524, device='cuda:0')
tensor(0.9515, device='cuda:0')
tensor(0.9521, device='cuda:0')
tensor(0.9521, device='cuda:0')
tensor(0.9528, device='cuda:0')
tensor(0.9523, device='cuda:0')
tensor(0.9544, device='cuda:0')
tensor(0.9523, device='cuda:0')
tensor(0.9519, device='cuda:0')
tensor(0.9519, device='cuda:0')
tensor(0.9532, device='cuda:0')
tensor(0.9528, device='cuda:0')
tensor(0.9536, device='cuda:0')
tensor(0

In [49]:
from sklearn.metrics import classification_report, confusion_matrix , f1_score
f1 = f1_score(true_labels, pred_labels)
print(f1)
# print(classification_report(true_labels, pred_labels))
# print(confusion_matrix(true_labels, pred_labels))

0.8888888888888888


In [50]:
print(pred_labels)
print(true_labels)

[1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1)]


In [45]:
print(anomalies[:8])

[('wormed', tensor(0.9527, device='cuda:0')), ('no wormed', tensor(0.9556, device='cuda:0')), ('no wormed', tensor(0.9549, device='cuda:0')), ('wormed', tensor(0.9529, device='cuda:0')), ('no wormed', tensor(0.9564, device='cuda:0')), ('no wormed', tensor(0.9568, device='cuda:0')), ('no wormed', tensor(0.9565, device='cuda:0')), ('no wormed', tensor(0.9563, device='cuda:0'))]


In [20]:
print(anomalies[8:])


[('wormed', 158.14205932617188), ('wormed', 178.74172973632812), ('wormed', 207.54083251953125), ('wormed', 196.41542053222656), ('wormed', 223.1368408203125), ('no wormed', 247.09848022460938), ('wormed', 217.18492126464844), ('wormed', 192.02735900878906), ('wormed', 186.51715087890625), ('wormed', 219.11318969726562), ('wormed', 223.01605224609375), ('no wormed', 253.3943634033203), ('wormed', 230.77816772460938), ('wormed', 207.63681030273438), ('no wormed', 343.3358154296875), ('wormed', 187.82064819335938), ('no wormed', 387.0472106933594), ('no wormed', 355.0072021484375), ('no wormed', 351.0190124511719), ('no wormed', 369.83599853515625), ('wormed', 193.2080535888672), ('wormed', 210.8118438720703), ('no wormed', 282.4904479980469), ('wormed', 227.77255249023438), ('wormed', 198.7176055908203), ('wormed', 193.2436981201172), ('wormed', 219.24148559570312)]


In [53]:
# ==============
# This cell is for testing some posetive and negative samples the repeated code here is for loading model 
# ==============


# Hyperparameters
nz = 100  # Latent vector size
ngf = 224  # Generator feature map size
ndf = 224  # Discriminator feature map size
nc = 3    # Number of channels (1 for grayscale, 3 for RGB)
batch_size = 16
epochs = 150
lr = 0.0002
beta1 = 0.5  # Adam optimizer beta1
early_stop_patient = 150


# Data loading
data_root = "./test"
splits = ["test"]

data_dict = {}
datasets_dict = {}


data_transforms = transforms.Compose([
    RandomRotation(degrees=30),  # Randomly rotate images by up to 30 degrees
    RandomHorizontalFlip(p=0.5),  # Randomly flip images horizontally with a probability of 0.5
    RandomVerticalFlip(p=0.5),  # Randomly flip images vertically with a probability of 0.5
    RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0))  # Randomly crop and resize images to 128x128 with a scale between 80% and 100% of the original size
])


# Load CSV files and create datasets
for split in splits:
    try:
        # Load CSV data
        csv_path =  f"{data_root}/test.csv"
        data_dict[split] = pd.read_csv(csv_path)
        # print(data_dict["train"]['filename'].values)
        # Create dataset for each split
        datasets_dict[split] = TestDataset(
            images_root="./test/",
            image_files=data_dict[split]['filename'].values,
            transform=transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor()
            ]),
            augmentations=data_transforms if split == "train" else None
        )
    except Exception as e:
        print(f"Error loading {split} split: {e}")

# Create dataloaders
test_data_loader = DataLoader(
    datasets_dict["test"],
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

print("Available splits:", list(datasets_dict.keys()))
print("Dataset sizes:", {split: len(dataset) for split, dataset in datasets_dict.items()})


Available splits: ['test']
Dataset sizes: {'test': 17}


In [66]:
images, reconstructed ,anomalies, true_labels, pred_labels = test_gan(gen= generator, disc = discriminator , test_data_loader=test_data_loader, anomaly_threshold=th ,device=device)
f1 = f1_score(true_labels, pred_labels)
print(f1)

tensor(0.9518, device='cuda:0')
tensor(0.9549, device='cuda:0')
tensor(0.9533, device='cuda:0')
tensor(0.9528, device='cuda:0')
tensor(0.9522, device='cuda:0')
tensor(0.9532, device='cuda:0')
tensor(0.9531, device='cuda:0')
tensor(0.9533, device='cuda:0')
tensor(0.9521, device='cuda:0')
tensor(0.9563, device='cuda:0')
tensor(0.9536, device='cuda:0')
tensor(0.9535, device='cuda:0')
tensor(0.9511, device='cuda:0')
tensor(0.9504, device='cuda:0')
tensor(0.9507, device='cuda:0')
tensor(0.9520, device='cuda:0')
tensor(0.9515, device='cuda:0')
0.6666666666666666


In [67]:
print(pred_labels)
print(true_labels)

[1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1]
[np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0)]


In [68]:
print(anomalies[:9])


[('wormed', tensor(0.9518, device='cuda:0')), ('no wormed', tensor(0.9549, device='cuda:0')), ('wormed', tensor(0.9533, device='cuda:0')), ('wormed', tensor(0.9528, device='cuda:0')), ('wormed', tensor(0.9522, device='cuda:0')), ('wormed', tensor(0.9532, device='cuda:0')), ('wormed', tensor(0.9531, device='cuda:0')), ('wormed', tensor(0.9533, device='cuda:0')), ('wormed', tensor(0.9521, device='cuda:0'))]


In [69]:
print(anomalies[9:])


[('no wormed', tensor(0.9563, device='cuda:0')), ('wormed', tensor(0.9536, device='cuda:0')), ('wormed', tensor(0.9535, device='cuda:0')), ('wormed', tensor(0.9511, device='cuda:0')), ('wormed', tensor(0.9504, device='cuda:0')), ('wormed', tensor(0.9507, device='cuda:0')), ('wormed', tensor(0.9520, device='cuda:0')), ('wormed', tensor(0.9515, device='cuda:0'))]
