In [1]:
import torch
from torch import nn

from Dataset import Image_dataset
from torch.utils.data import Dataset, DataLoader
from collections import Counter
from torch.utils.data import WeightedRandomSampler

from Critic import Discriminator
import torch.optim as optim
from PIL import Image
import numpy as np
from torchvision import transforms

from config import *

In [2]:
ds = Image_dataset("Dataset")

labels = []
for i in range(len(ds)):
    _, label = ds[i]
    labels.append(ds[i][1])

class_counts = Counter(labels)
total_samples = len(labels)

weight_per_class = {cls: total_samples/count for cls, count in class_counts.items()}
weights = [weight_per_class[label] for label in labels]

sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
dl = DataLoader(ds, batch_size = 5, sampler=sampler)


In [3]:
critic = Discriminator(in_channels=3)
optimizer = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
BCE = nn.BCEWithLogitsLoss()

In [None]:
# Set the critic model to training mode
critic.train()
for epoch in range(EPOCHS):
    batch_loss = 0
    for data in dl:
        # Move input images to DEVICE and convert to float
        x = data[0].to(DEVICE)
        x = x.float()
        # Forward pass through the critic, permute to (batch, channels, height, width)
        critic_output = critic(x.permute(0, -1, 1, 2))
        # Remove the channel dimension if it's 1
        logits = critic_output.squeeze(1)
        
        # Get the labels from the batch
        labels = data[1]
        # Convert string labels to float targets: 1.0 for "Not Defective", 0.0 for "Defective"
        targets = torch.tensor(
            [1.0 if label == "Not Defective" else 0.0 for label in labels],
            dtype=torch.float32,
            device=critic_output.device
        )
        # Expand targets to match the shape of logits for pixel-wise loss
        targets = torch.stack([torch.ones(logits.shape[1:]) if i == 1. else torch.zeros(logits.shape[1:]) for i in targets], 0).to(critic_output.device)
        
        # Compute the binary cross-entropy loss with logits
        loss = BCE(logits, targets)
        batch_loss += loss.item()
        # Zero gradients, backpropagate, and update weights
        critic.zero_grad()
        loss.backward()
        optimizer.step()
    
    # Print average batch loss for the epoch
    print(f"Epoch: {epoch} | Batch Loss: {batch_loss/len(dl):.4f}")

Epoch: 0 | Batch Loss: 0.5361
Epoch: 1 | Batch Loss: 0.4786
Epoch: 2 | Batch Loss: 0.4949
Epoch: 3 | Batch Loss: 0.3676
Epoch: 4 | Batch Loss: 0.4039
Epoch: 5 | Batch Loss: 0.5561
Epoch: 6 | Batch Loss: 0.4663
Epoch: 7 | Batch Loss: 0.3249
Epoch: 8 | Batch Loss: 0.2548
Epoch: 9 | Batch Loss: 0.5269


In [17]:
torch.save(critic.state_dict(), "critic.pth")

In [None]:
critic = Discriminator(in_channels=3)
critic.load_state_dict(torch.load("critic.pth", map_location=torch.device('cpu'), weights_only=True))
critic.eval()

In [23]:
image_path = "/Users/mohamedmafaz/Desktop/CountAI/Dataset/59.jpg"

transform = transforms.Compose([
    transforms.Resize((256, 256)),  
    transforms.ToTensor(),          
])

image = Image.open(image_path).convert("RGB")  
tensor_image = transform(image).unsqueeze(0)  

result = critic(tensor_image)
mat_min = result.min()
mat_max = result.max()
mat_normalized = (result - mat_min) / (mat_max - mat_min)
mat_normalized = mat_normalized.view(tensor_image.shape[0], -1).mean(1)

if mat_normalized > 0.5:
    print("Not Defective")
else:
    print("Defective")

Defective


In [24]:
mat_normalized

tensor([0.4401], grad_fn=<MeanBackward1>)