In [42]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from PIL import ImageFile
from PIL import Image

ImageFile.LOAD_TRUNCATED_IMAGES = True

def safe_loader(path):
    try:
        with Image.open(path) as img:
            return img.convert('RGB')
    except:
        return None

class SafeImageFolder(datasets.ImageFolder):
    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)
        if sample is None:
            # Skip or handle unreadable images
            return None, None
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

def safe_collate_fn(batch):
    # Exclude unreadable samples
    batch = [(x, y) for (x, y) in batch if x is not None and y is not None]
    if len(batch) == 0:
        return None, None
    return default_collate(batch)

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load dataset
data_dir = "/kaggle/input/dataset/OC Dataset kaggle new"
train_data = SafeImageFolder(
    root=data_dir + '/train',
    transform=transform,
    loader=safe_loader
)
validation_data = SafeImageFolder(
    root=data_dir + '/valid',
    transform=transform,
    loader=safe_loader
)

trainloader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=safe_collate_fn)
valloader = DataLoader(validation_data, batch_size=32, shuffle=False, collate_fn=safe_collate_fn)

# Define model
model = models.resnet18(pretrained=True)
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 1),
    nn.Sigmoid()
)
model = model.to(device)

# Define loss and optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 10
best_val_loss = float('inf')

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device).float().unsqueeze(1)
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        
        # Binary prediction threshold at 0.5
        preds = (outputs >= 0.5).float()
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / len(train_data)
    epoch_acc = correct / total if total > 0 else 0
    print(f"Epoch {epoch+1}/{epochs}, Training Loss: {epoch_loss:.4f}, Training Accuracy: {epoch_acc:.4f}")
    
    # Validation loop
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for inputs, labels in valloader:
            inputs, labels = inputs.to(device), labels.to(device).float().unsqueeze(1)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * inputs.size(0)

            preds = (outputs >= 0.5).float()
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    val_loss /= len(validation_data)
    val_acc = val_correct / val_total if val_total > 0 else 0
    print(f"Epoch {epoch+1}/{epochs}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_model.pth")

print("Training complete.")


Epoch 1/10, Training Loss: 0.5808, Training Accuracy: 0.8439
Epoch 1/10, Validation Loss: 0.7108, Validation Accuracy: 0.6633
Epoch 2/10, Training Loss: 0.5606, Training Accuracy: 0.8970
Epoch 2/10, Validation Loss: 0.7898, Validation Accuracy: 0.5255
Epoch 3/10, Training Loss: 0.5490, Training Accuracy: 0.9167
Epoch 3/10, Validation Loss: 0.7111, Validation Accuracy: 0.6071
Epoch 4/10, Training Loss: 0.5504, Training Accuracy: 0.9061
Epoch 4/10, Validation Loss: 0.6382, Validation Accuracy: 0.7602
Epoch 5/10, Training Loss: 0.5435, Training Accuracy: 0.9242
Epoch 5/10, Validation Loss: 0.6561, Validation Accuracy: 0.7449
Epoch 6/10, Training Loss: 0.5546, Training Accuracy: 0.8955
Epoch 6/10, Validation Loss: 0.6489, Validation Accuracy: 0.6684
Epoch 7/10, Training Loss: 0.5536, Training Accuracy: 0.8879
Epoch 7/10, Validation Loss: 0.6569, Validation Accuracy: 0.7245
Epoch 8/10, Training Loss: 0.5431, Training Accuracy: 0.9167
Epoch 8/10, Validation Loss: 0.5902, Validation Accuracy: