In [1]:
import torch
from torch import nn

from Dataset import Image_dataset
from torch.utils.data import Dataset, DataLoader
from collections import Counter
from torch.utils.data import WeightedRandomSampler

from Critic import Discriminator
import torch.optim as optim
from PIL import Image
import numpy as np
from torchvision import transforms

from config import *

In [2]:
ds = Image_dataset("Dataset")

labels = []
for i in range(len(ds)):
    _, label = ds[i]
    labels.append(ds[i][1])

class_counts = Counter(labels)
total_samples = len(labels)

weight_per_class = {cls: total_samples/count for cls, count in class_counts.items()}
weights = [weight_per_class[label] for label in labels]

sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
dl = DataLoader(ds, batch_size = 5, sampler=sampler)


In [3]:
critic = Discriminator(in_channels=3)
optimizer = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
BCE = nn.BCEWithLogitsLoss()

In [72]:
for epoch in range(EPOCHS):
    batch_loss = 0
    for data in dl:

        x = data[0].to(DEVICE)
        x = x.float()
        critic_output = critic(x.permute(0, -1, 1, 2))
        # logits = critic_output.view(critic_output.shape[0], -1).mean(1)

        logits = critic_output.squeeze(1)
        
        labels = data[1]
        targets = torch.tensor(
            [1.0 if label == "Not Defective" else 0.0 for label in labels],
            dtype=torch.float32,
            device=critic_output.device
        )

        targets = torch.stack([torch.ones(logits.shape[1:]) if i == 1. else torch.zeros(logits.shape[1:]) for i in targets], 0).to(critic_output.device)
        
        loss = BCE(logits, targets)
        batch_loss += loss.item()
        critic.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(f"Epoch: {epoch} | Batch Loss: {batch_loss/len(dl):.4f}")

Epoch: 0 | Batch Loss: 0.1753
Epoch: 1 | Batch Loss: 0.2671
Epoch: 2 | Batch Loss: 0.4287


In [71]:
torch.save(critic.state_dict(), "critic_model.pth")

In [76]:
image_path = "/Users/mohamedmafaz/Desktop/CountAI/Dataset/2.jpg"

transform = transforms.Compose([
    transforms.Resize((256, 256)),  
    transforms.ToTensor(),          
])

image = Image.open(image_path).convert("RGB")  
tensor_image = transform(image).unsqueeze(0)  

result = critic(tensor_image)
mat_min = result.min()
mat_max = result.max()
mat_normalized = (result - mat_min) / (mat_max - mat_min)
mat_normalized = mat_normalized.view(tensor_image.shape[0], -1).mean(1)

if mat_normalized > 0.5:
    print("Not Defective")
else:
    print("Defective")

Not Defective


In [77]:
mat_normalized

tensor([0.7174], grad_fn=<MeanBackward1>)